//===- LinalgOps.cpp - Implementation of the linalg operations ------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the Linalg operations.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"

#include "mlir/AsmParser/AsmParser.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Complex/IR/Complex.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Support/raw_ostream.h"
#include <optional>

using namespace mlir;
using namespace mlir::linalg;

/// Return a `memref.dim` or `tensor.dim` for the shape of `v` at `dim`.
static OpFoldResult getDimValue(OpBuilder &builder, Location loc, Value v,
                                int64_t dim) {
  auto type = cast<ShapedType>(v.getType());
  if (!type.isDynamicDim(dim))
    return builder.getIndexAttr(type.getDimSize(dim));

  return getAsOpFoldResult(
      TypeSwitch<Type, Value>(v.getType())
          .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
            return builder.create<tensor::DimOp>(loc, v, dim);
          })
          .Case<MemRefType>([&](MemRefType t) -> Value {
            return builder.create<memref::DimOp>(loc, v, dim);
          }));
}

/// Returns a memref.subview or a tensor.extract_slice based on the type of the
/// `source`.
static Value getSlice(OpBuilder &b, Location loc, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides) {
  return TypeSwitch<Type, Value>(source.getType())
      .Case<RankedTensorType>([&](RankedTensorType t) -> Value {
        return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                                strides);
      })
      .Case<MemRefType>([&](MemRefType type) -> Value {
        return b.create<memref::SubViewOp>(loc, source, offsets, sizes,
                                           strides);
      })
      .Default([&](Type t) { return nullptr; });
}
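
// Illustrative use of the helper above (a sketch, not code invoked elsewhere
// in this file): for a tensor-typed `source`, getSlice materializes
//   %s = tensor.extract_slice %source[0] [4] [1]
//       : tensor<8xf32> to tensor<4xf32>
// while the same call on a memref-typed `source` produces a memref.subview.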

//===----------------------------------------------------------------------===//
// Helper functions
//===----------------------------------------------------------------------===//

Value linalg::createOrFoldDimOp(OpBuilder &b, Location loc, Value source,
                                int64_t dim) {
  if (llvm::isa<UnrankedMemRefType, MemRefType>(source.getType()))
    return b.createOrFold<memref::DimOp>(loc, source, dim);
  if (llvm::isa<UnrankedTensorType, RankedTensorType>(source.getType()))
    return b.createOrFold<tensor::DimOp>(loc, source, dim);
  llvm_unreachable("Expected MemRefType or TensorType");
}

OpFoldResult linalg::createFoldedDimOp(OpBuilder &b, Location loc, Value source,
                                       int64_t dim) {
  auto shapedType = llvm::cast<ShapedType>(source.getType());
  if (!shapedType.hasRank() || shapedType.isDynamicDim(dim))
    return createOrFoldDimOp(b, loc, source, dim);
  return b.getIndexAttr(shapedType.getDimSize(dim));
}

//===----------------------------------------------------------------------===//
// Support for named Linalg ops defined in ods-gen.
//===----------------------------------------------------------------------===//

using RegionBuilderFn = llvm::function_ref<void(ImplicitLocOpBuilder &, Block &,
                                                ArrayRef<NamedAttribute>)>;

/// Fills the region of a structured operation using the provided
/// `regionBuilder`. The method is used by both named structured ops created by
/// ods-gen and by manually defined C++ ops. It is called by both builders and
/// parsers and creates a block with arguments corresponding to the elemental
/// types of `inputTypes` and `outputTypes`. All output types are asserted to be
/// ShapedType.
static void fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
                                   TypeRange inputTypes, TypeRange outputTypes,
                                   ArrayRef<NamedAttribute> attrs,
                                   RegionBuilderFn regionBuilder) {
  assert(llvm::all_of(outputTypes, llvm::IsaPred<ShapedType>));

  SmallVector<Type, 8> argTypes;
  SmallVector<Location, 8> argLocs;
  for (auto containers : {inputTypes, outputTypes}) {
    for (auto t : containers) {
      argTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);

      // TODO: Pass in a proper location here.
      argLocs.push_back(opBuilder.getUnknownLoc());
    }
  }

  // RAII.
  OpBuilder::InsertionGuard guard(opBuilder);
  Block *body =
      opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);

  opBuilder.setInsertionPointToStart(body);
  ImplicitLocOpBuilder b(opBuilder.getUnknownLoc(), opBuilder);
  regionBuilder(b, *body, attrs);

  // indexing_maps is an auto-generated method.

  // iterator_types is an auto-generated method.
}

/// Creates a structured operation given `inputs`, `outputs`, and `attributes`.
/// The result types are derived automatically if `resultTensorTypes` is none.
/// The body of the operation is filled using `regionBuilder`. All ods-gen
/// created structured operations use the method to implement their builders.
static void buildStructuredOp(OpBuilder &b, OperationState &state,
                              std::optional<TypeRange> resultTensorTypes,
                              ValueRange inputs, ValueRange outputs,
                              ArrayRef<NamedAttribute> attributes,
                              RegionBuilderFn regionBuilder) {
  // Derive the result types if needed.
  SmallVector<Type> derivedResultTypes =
      resultTensorTypes.value_or(TypeRange());
  if (!resultTensorTypes)
    copy_if(outputs.getTypes(), std::back_inserter(derivedResultTypes),
            llvm::IsaPred<RankedTensorType>);

  state.addOperands(inputs);
  state.addOperands(outputs);
  state.addTypes(derivedResultTypes);
  state.addAttributes(attributes);
  state.addAttribute(
      "operandSegmentSizes",
      b.getDenseI32ArrayAttr({static_cast<int32_t>(inputs.size()),
                              static_cast<int32_t>(outputs.size())}));

  // Create and fill the region of the structured operation.
  Region &region = *state.addRegion();
  fillStructuredOpRegion(b, region, TypeRange(inputs), TypeRange(outputs),
                         state.attributes.getAttrs(), regionBuilder);
}

/// Common parsing used for both named structured ops created by ods-gen and by
/// manually defined C++ ops. Does not handle regions.
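///
/// A sketch of the accepted `ins`/`outs` syntax (illustrative; not the full
/// grammar):
///   ins(%lhs, %rhs : tensor<4x8xf32>, tensor<8x16xf32>)
///   outs(%init : tensor<4x16xf32>)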
static ParseResult
parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                             SmallVectorImpl<Type> &inputTypes,
                             SmallVectorImpl<Type> &outputTypes,
                             bool addOperandSegmentSizes = true) {
  SMLoc attrsLoc, inputsOperandsLoc, outputsOperandsLoc;
  SmallVector<OpAsmParser::UnresolvedOperand, 4> inputsOperands,
      outputsOperands;

  if (succeeded(parser.parseOptionalLess())) {
    if (parser.parseAttribute(result.propertiesAttr) || parser.parseGreater())
      return failure();
  }
  attrsLoc = parser.getCurrentLocation();
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  if (succeeded(parser.parseOptionalKeyword("ins"))) {
    if (parser.parseLParen())
      return failure();

    inputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseOperandList(inputsOperands) ||
        parser.parseColonTypeList(inputTypes) || parser.parseRParen())
      return failure();
  }

  if (succeeded(parser.parseOptionalKeyword("outs"))) {
    outputsOperandsLoc = parser.getCurrentLocation();
    if (parser.parseLParen() || parser.parseOperandList(outputsOperands) ||
        parser.parseColonTypeList(outputTypes) || parser.parseRParen())
      return failure();
  }

  if (parser.resolveOperands(inputsOperands, inputTypes, inputsOperandsLoc,
                             result.operands) ||
      parser.resolveOperands(outputsOperands, outputTypes, outputsOperandsLoc,
                             result.operands))
    return failure();

  if (addOperandSegmentSizes) {
    // This is a bit complex because we're trying to be backward compatible
    // with operation syntax that mixes the inherent attributes and the
    // discardable ones in the same dictionary. If the properties are used, we
    // append the operandSegmentSizes there directly. Otherwise we append it to
    // the discardable attributes dictionary where it is handled by the generic
    // Operation::create(...) method.
    if (result.propertiesAttr) {
      NamedAttrList attrs = llvm::cast<DictionaryAttr>(result.propertiesAttr);
      attrs.append("operandSegmentSizes",
                   parser.getBuilder().getDenseI32ArrayAttr(
                       {static_cast<int32_t>(inputsOperands.size()),
                        static_cast<int32_t>(outputsOperands.size())}));
      result.propertiesAttr = attrs.getDictionary(parser.getContext());
    } else {
      result.addAttribute("operandSegmentSizes",
                          parser.getBuilder().getDenseI32ArrayAttr(
                              {static_cast<int32_t>(inputsOperands.size()),
                               static_cast<int32_t>(outputsOperands.size())}));
    }
  }
  if (!result.propertiesAttr) {
    std::optional<RegisteredOperationName> info =
        result.name.getRegisteredInfo();
    if (info) {
      if (failed(info->verifyInherentAttrs(result.attributes, [&]() {
            return parser.emitError(attrsLoc)
                   << "'" << result.name.getStringRef() << "' op ";
          })))
        return failure();
    }
  }
  return success();
}

static void printCommonStructuredOpParts(OpAsmPrinter &p, ValueRange inputs,
                                         ValueRange outputs) {
  if (!inputs.empty())
    p << " ins(" << inputs << " : " << inputs.getTypes() << ")";
  if (!outputs.empty())
    p << " outs(" << outputs << " : " << outputs.getTypes() << ")";
}

//===----------------------------------------------------------------------===//
// Specific parsing and printing for named structured ops created by ods-gen.
//===----------------------------------------------------------------------===//

static ParseResult parseNamedStructuredOpRegion(
    OpAsmParser &parser, Region &region, unsigned numRegionArgs,
    TypeRange inputTypes, TypeRange outputTypes, ArrayRef<NamedAttribute> attrs,
    RegionBuilderFn regionBuilder) {
  if (numRegionArgs != inputTypes.size() + outputTypes.size()) {
    return parser.emitError(
        parser.getCurrentLocation(),
        llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
                      "region expects {0} args, got {1}",
                      numRegionArgs, inputTypes.size() + outputTypes.size()));
  }

  OpBuilder opBuilder(parser.getContext());
  fillStructuredOpRegion(opBuilder, region, inputTypes, outputTypes, attrs,
                         regionBuilder);
  return success();
}

static ParseResult
parseNamedStructuredOpResults(OpAsmParser &parser,
                              SmallVectorImpl<Type> &resultTypes) {
  if (parser.parseOptionalArrowTypeList(resultTypes))
    return failure();
  return success();
}

static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
                                          OperationState &result,
                                          unsigned numRegionArgs,
                                          RegionBuilderFn regionBuilder) {
  // TODO: Enable when ods-gen supports captures.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // TODO: consider merging results parsing into region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parseNamedStructuredOpRegion(parser, *region, numRegionArgs, inputTypes,
                                   outputTypes, result.attributes.getAttrs(),
                                   regionBuilder))
    return failure();
  result.addRegion(std::move(region));

  return success();
}

static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                          TypeRange resultTypes) {
  if (resultTypes.empty())
    return;
  p.printOptionalArrowTypeList(resultTypes);
}

static void printNamedStructuredOp(OpAsmPrinter &p, Operation *op,
                                   ValueRange inputs, ValueRange outputs) {
  p.printOptionalAttrDict(
      op->getAttrs(),
      /*elidedAttrs=*/{"operandSegmentSizes",
                       // See generated code in
                       // LinalgNamedStructuredOps.yamlgen.cpp.inc
                       "linalg.memoized_indexing_maps"});

  // Printing is shared with generic ops, except for the region and
  // attributes.
  printCommonStructuredOpParts(p, inputs, outputs);

  // Results printing.
  printNamedStructuredOpResults(p, op->getResultTypes());

  // Region is elided.
}

//===----------------------------------------------------------------------===//
// Region builder helper.
// TODO: Move this to a utility library.
// The public methods on this class are referenced directly from generated code.
// Helpers to build the unary, binary, and type conversion functions defined by
// the DSL. See LinalgNamedStructuredOps.yamlgen.cpp.inc for the code that uses
// this class.
//
// Implementations of the math functions must be polymorphic over numeric types,
// internally performing necessary casts. If the function application makes no
// sense, then the only recourse is to assert and return nullptr. This can be
// extended later if it becomes possible to fail construction of the region. The
// invariant should be enforced at a higher level.
//
// TODO: These helpers are currently type polymorphic over the class of integer
// and floating point types, but they will not internally cast within bit
// widths of a class (mixed precision such as i8->i32) or across classes
// (i.e. mixed float and integer). Many such combinations are ambiguous or need
// to be handled with care and work is being considered to extend the op
// language to make such cases explicit. In the meantime, violating this will
// fail verification, which is deemed acceptable.
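//
// For example (per the switch in buildBinaryFn below), BinaryFn::add lowers to
// complex.add for complex operands, arith.addf for floats, arith.ori for i1
// "booleans", and arith.addi for other integers.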
//===----------------------------------------------------------------------===//

namespace {

class RegionBuilderHelper {
public:
  RegionBuilderHelper(OpBuilder &builder, Block &block)
      : builder(builder), block(block) {}

  // Build the unary functions defined by OpDSL.
  Value buildUnaryFn(UnaryFn unaryFn, Value arg) {
    if (!isFloatingPoint(arg))
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (unaryFn) {
    case UnaryFn::exp:
      return builder.create<math::ExpOp>(arg.getLoc(), arg);
    case UnaryFn::log:
      return builder.create<math::LogOp>(arg.getLoc(), arg);
    case UnaryFn::abs:
      return builder.create<math::AbsFOp>(arg.getLoc(), arg);
    case UnaryFn::ceil:
      return builder.create<math::CeilOp>(arg.getLoc(), arg);
    case UnaryFn::floor:
      return builder.create<math::FloorOp>(arg.getLoc(), arg);
    case UnaryFn::negf:
      return builder.create<arith::NegFOp>(arg.getLoc(), arg);
    }
    llvm_unreachable("unsupported unary function");
  }

  // Build the binary functions defined by OpDSL.
  Value buildBinaryFn(BinaryFn binaryFn, Value arg0, Value arg1) {
    bool allComplex = isComplex(arg0) && isComplex(arg1);
    bool allFloatingPoint = isFloatingPoint(arg0) && isFloatingPoint(arg1);
    bool allInteger = isInteger(arg0) && isInteger(arg1);
    bool allBool = allInteger && arg0.getType().getIntOrFloatBitWidth() == 1 &&
                   arg1.getType().getIntOrFloatBitWidth() == 1;
    if (!allComplex && !allFloatingPoint && !allInteger)
      llvm_unreachable("unsupported non numeric type");
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    switch (binaryFn) {
    case BinaryFn::add:
      if (allComplex)
        return builder.create<complex::AddOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::AddFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::OrIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::AddIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::sub:
      if (allComplex)
        return builder.create<complex::SubOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::SubFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: sub with bools");
      return builder.create<arith::SubIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::mul:
      if (allComplex)
        return builder.create<complex::MulOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::MulFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        return builder.create<arith::AndIOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MulIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div:
      if (allComplex)
        return builder.create<complex::DivOp>(arg0.getLoc(), arg0, arg1);
      if (allFloatingPoint)
        return builder.create<arith::DivFOp>(arg0.getLoc(), arg0, arg1);
      if (allBool)
        llvm_unreachable("unsupported operation: div with bools");
      return builder.create<arith::DivSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::div_unsigned:
      if (!allInteger || allBool)
        llvm_unreachable("unsupported operation: unsigned div not on uint");
      return builder.create<arith::DivUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_signed:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinSIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::max_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MaximumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MaxUIOp>(arg0.getLoc(), arg0, arg1);
    case BinaryFn::min_unsigned:
      assert(!allComplex);
      if (allFloatingPoint)
        return builder.create<arith::MinimumFOp>(arg0.getLoc(), arg0, arg1);
      return builder.create<arith::MinUIOp>(arg0.getLoc(), arg0, arg1);
    }
    llvm_unreachable("unsupported binary function");
  }

  // Build the type functions defined by OpDSL.
  Value buildTypeFn(TypeFn typeFn, Type toType, Value operand) {
    switch (typeFn) {
    case TypeFn::cast_signed:
      return cast(toType, operand, /*isUnsignedCast=*/false);
    case TypeFn::cast_unsigned:
      return cast(toType, operand, /*isUnsignedCast=*/true);
    }
    llvm_unreachable("unsupported type conversion function");
  }

  void yieldOutputs(ValueRange values) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    builder.create<YieldOp>(loc, values);
  }

  Value constant(const std::string &value) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    Location loc = builder.getUnknownLoc();
    Attribute valueAttr = parseAttribute(value, builder.getContext());
    return builder.create<arith::ConstantOp>(loc, ::cast<TypedAttr>(valueAttr));
  }

  Value index(int64_t dim) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    return builder.create<IndexOp>(builder.getUnknownLoc(), dim);
  }

  Type getIntegerType(unsigned width) {
    return IntegerType::get(builder.getContext(), width);
  }

  Type getFloat32Type() { return Float32Type::get(builder.getContext()); }
  Type getFloat64Type() { return Float64Type::get(builder.getContext()); }

private:
  // Generates operations to cast the given operand to a specified type.
  // If the cast cannot be performed, a warning will be issued and the
  // operand returned as-is (which will presumably yield a verification
  // issue downstream).
  Value cast(Type toType, Value operand, bool isUnsignedCast) {
    OpBuilder::InsertionGuard g(builder);
    builder.setInsertionPointToEnd(&block);
    auto loc = operand.getLoc();
    return convertScalarToDtype(builder, loc, operand, toType, isUnsignedCast);
  }

  bool isComplex(Value value) {
    return llvm::isa<ComplexType>(value.getType());
  }
  bool isFloatingPoint(Value value) {
    return llvm::isa<FloatType>(value.getType());
  }
  bool isInteger(Value value) {
    return llvm::isa<IntegerType>(value.getType());
  }

  OpBuilder &builder;
  Block &block;
};

} // namespace

//===----------------------------------------------------------------------===//
// CopyOp
//===----------------------------------------------------------------------===//

namespace {

/// Erase a linalg.copy whose input and output are the same value: on buffers
/// the op is a no-op and can simply be erased; on tensors the results are
/// replaced by the inputs.
struct EraseSelfCopy : OpRewritePattern<CopyOp> {
  using OpRewritePattern<CopyOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (copyOp.getInputs() != copyOp.getOutputs())
      return rewriter.notifyMatchFailure(copyOp, "not a self copy");
    if (copyOp.hasPureBufferSemantics())
      rewriter.eraseOp(copyOp);
    else
      rewriter.replaceOp(copyOp, copyOp.getInputs());

    return success();
  }
};

} // namespace

void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<EraseSelfCopy>(context);
}

//===----------------------------------------------------------------------===//
// FillOp
//===----------------------------------------------------------------------===//

namespace {

/// Fold a linalg.fill -> tensor.expand/collapse_shape chain.
///
/// For such op chains, we can create new linalg.fill ops with the result
/// type of the tensor.expand/collapse_shape op.
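///
/// Example (illustrative shapes):
///   %fill = linalg.fill ins(%cst : f32) outs(%init : tensor<4x8xf32>)
///   %r = tensor.collapse_shape %fill [[0, 1]]
///       : tensor<4x8xf32> into tensor<32xf32>
/// becomes
///   %init2 = tensor.collapse_shape %init [[0, 1]]
///       : tensor<4x8xf32> into tensor<32xf32>
///   %r = linalg.fill ins(%cst : f32) outs(%init2 : tensor<32xf32>)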
template <typename TensorReshapeOp>
struct FoldFillWithTensorReshape : OpRewritePattern<TensorReshapeOp> {
  using OpRewritePattern<TensorReshapeOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(TensorReshapeOp reshapeOp,
                                PatternRewriter &rewriter) const override {
    auto oldFill = reshapeOp.getSrc().template getDefiningOp<FillOp>();
    if (!oldFill)
      return failure();

    Location loc = oldFill.getLoc();
    auto newInit = rewriter.create<TensorReshapeOp>(
        loc, reshapeOp.getResultType(), oldFill.output(),
        reshapeOp.getReassociation());
    rewriter.replaceOpWithNewOp<FillOp>(reshapeOp, ValueRange{oldFill.value()},
                                        ValueRange{newInit});

    return success();
  }
};

/// Fold tensor.pad(linalg.fill) into linalg.fill if the padding value and the
/// filling value are the same.
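/// Padding a tensor that was filled with %cst by that same %cst writes %cst
/// to every element of the larger result, so the pattern rewrites the pad
/// into a fill of a fresh tensor.empty of the padded shape.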
struct FoldFillWithPad final : public OpRewritePattern<tensor::PadOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::PadOp padOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = padOp.getSource().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = padOp.getConstantPaddingValue();
    if (!padValue || fillOp.value() != padValue)
      return failure();

    ReifiedRankedShapedTypeDims reifiedShape;
    if (failed(reifyResultShapes(rewriter, padOp, reifiedShape)))
      return rewriter.notifyMatchFailure(
          padOp, "failed to reify tensor.pad op result shape");

    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
        padOp.getLoc(), reifiedShape.front(),
        padOp.getResultType().getElementType());
    Value replacement =
        rewriter
            .create<FillOp>(fillOp.getLoc(), ValueRange{padValue},
                            ValueRange{emptyTensor})
            .getResult(0);
    if (replacement.getType() != padOp.getResultType()) {
      replacement = rewriter.create<tensor::CastOp>(
          fillOp.getLoc(), padOp.getResultType(), replacement);
    }
    rewriter.replaceOp(padOp, replacement);
    return success();
  }
};

/// Fold tensor.insert_slice(tensor.pad(<input>), linalg.fill) into
/// tensor.insert_slice(<input>, linalg.fill) if the padding value and the
/// filling value are the same.
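/// The rewrite inserts the unpadded source directly and folds the low-pad
/// sizes into the insertion offsets. This is only sound when every other
/// slice inserted into the fill is provably disjoint from this one, which is
/// checked below.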
struct FoldInsertPadIntoFill : public OpRewritePattern<tensor::InsertSliceOp> {
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::InsertSliceOp insertOp,
                                PatternRewriter &rewriter) const override {
    auto srcPadOp = insertOp.getSource().getDefiningOp<tensor::PadOp>();
    if (!srcPadOp)
      return failure();

    if (insertOp.getType().getRank() != insertOp.getSourceType().getRank())
      return failure();

    // Walk back the tensor.insert_slice chain and find the first destination
    // value at the start of the chain.
    Value firstDest = insertOp.getDest();
    while (auto prevOp = firstDest.getDefiningOp<tensor::InsertSliceOp>()) {
      if (prevOp.getType().getRank() != prevOp.getSourceType().getRank())
        return failure();

      // Make sure the ranges of values accessed are disjoint. Without this,
      // we cannot fold tensor.pad away.
      bool disjoint = false;
      for (int i = 0, e = prevOp.getType().getRank(); i < e; ++i) {
        // If the dimension has a dynamic offset/size/stride, we cannot
        // guarantee disjointness, so just skip it.
        if (insertOp.isDynamicOffset(i) || insertOp.isDynamicSize(i) ||
            insertOp.isDynamicStride(i) || prevOp.isDynamicOffset(i) ||
            prevOp.isDynamicSize(i) || prevOp.isDynamicStride(i))
          continue;

        // Get the range start and end, inclusive for both.
        int64_t prevStart = prevOp.getStaticOffset(i);
        int64_t prevEnd = prevStart + (prevOp.getStaticSize(i) - 1) *
                                          prevOp.getStaticStride(i);
        int64_t nextStart = insertOp.getStaticOffset(i);
        int64_t nextEnd = nextStart + (insertOp.getStaticSize(i) - 1) *
                                          insertOp.getStaticStride(i);
        if (prevEnd < nextStart || nextEnd < prevStart) {
          disjoint = true;
          break;
        }
      }

      if (!disjoint)
        break;
      firstDest = prevOp.getDest();
    }

    // Check whether the first destination is a fill op. For overlapped cases,
    // this also cannot be true.
    auto dstFillOp = firstDest.getDefiningOp<linalg::FillOp>();
    if (!dstFillOp)
      return failure();

    // We can only fold if the padding value is the same as the original
    // filling value.
    Value padValue = srcPadOp.getConstantPaddingValue();
    if (!padValue || dstFillOp.value() != padValue)
      return failure();

    SmallVector<OpFoldResult> lowPads = srcPadOp.getMixedLowPad();
    SmallVector<OpFoldResult> oldOffsets = insertOp.getMixedOffsets();

    Location loc = insertOp.getLoc();
    MLIRContext *context = getContext();

    AffineExpr sym0, sym1;
    bindSymbols(context, sym0, sym1);
    auto addMap = AffineMap::get(0, 2, {sym0 + sym1}, context);

    // Calculate the new offsets for the insert. They should be the old
    // offsets plus the low padding sizes.
    SmallVector<OpFoldResult, 4> newOffsets;
    for (const auto &p : llvm::zip(lowPads, oldOffsets)) {
      newOffsets.push_back(affine::makeComposedFoldedAffineApply(
          rewriter, loc, addMap, {std::get<0>(p), std::get<1>(p)}));
    }

    RankedTensorType srcPadType = srcPadOp.getSourceType();
    SmallVector<OpFoldResult, 4> newSizes;
    for (int i = 0, e = srcPadType.getRank(); i < e; ++i) {
      if (srcPadType.isDynamicDim(i)) {
        newSizes.push_back(
            rewriter.create<tensor::DimOp>(loc, srcPadOp.getSource(), i)
                .getResult());
      } else {
        newSizes.push_back(rewriter.getIndexAttr(srcPadType.getDimSize(i)));
      }
    }

    rewriter.replaceOpWithNewOp<tensor::InsertSliceOp>(
        insertOp, srcPadOp.getSource(), insertOp.getDest(), newOffsets,
        newSizes, insertOp.getMixedStrides());
    return success();
  }
};

/// Fold tensor.extract(linalg.fill(<input>)) into <input>.
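/// Every element of a filled tensor is the fill value, e.g. (illustrative):
///   %t = linalg.fill ins(%cst : f32) outs(%init : tensor<4xf32>)
///   %e = tensor.extract %t[%i] : tensor<4xf32>
/// folds %e to %cst.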
struct FoldFillWithTensorExtract : public OpRewritePattern<tensor::ExtractOp> {
public:
  using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
                                PatternRewriter &rewriter) const override {
    // See if the tensor input of the tensor.extract op is the result of a
    // linalg.fill op.
    auto fillOp = extractOp.getTensor().getDefiningOp<linalg::FillOp>();
    if (!fillOp)
      return failure();

    // Get the scalar input operand of the linalg.fill op.
    Value extractedScalar = fillOp.getInputs()[0];

    // Replace the tensor.extract op with the scalar value used to fill the
    // tensor.
    rewriter.replaceOp(extractOp, extractedScalar);
    return success();
  }
};

/// Folds pack(fill) into a single fill op if
/// 1. The pack op does not have a padding value, or
/// 2. The filled value and padding value are the same.
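/// In both cases every element of the packed result equals the fill value, so
/// the pack can be replaced by filling its destination directly. The pack's
/// destination must have no other uses, which is checked below.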
static FailureOr<FillOp> foldFillPackIntoFillOp(RewriterBase &rewriter,
                                                tensor::PackOp packOp) {
  auto fillOp = packOp.getSource().getDefiningOp<FillOp>();
  if (!fillOp)
    return failure();

  if (auto paddingValue = packOp.getPaddingValue())
    if (!isEqualConstantIntOrValue(paddingValue, fillOp.value()))
      return failure();

  Value packOpDest = packOp.getDest();
  if (!packOpDest.hasOneUse())
    return failure();

  return rewriter.create<linalg::FillOp>(packOp.getLoc(), fillOp.getInputs(),
                                         packOp.getDest());
}

/// Wrapper pattern that applies the foldFillPackIntoFillOp method.
struct FoldFillWithPack : public OpRewritePattern<tensor::PackOp> {
public:
  FoldFillWithPack(MLIRContext *context)
      : OpRewritePattern<tensor::PackOp>(context) {}

  LogicalResult matchAndRewrite(tensor::PackOp packOp,
                                PatternRewriter &rewriter) const override {
    auto fillOp = foldFillPackIntoFillOp(rewriter, packOp);
    if (failed(fillOp))
      return failure();
    rewriter.replaceOp(packOp, fillOp.value().result());
    return success();
  }
};

/// Fold fill with copy.
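/// Copying from a fill is equivalent to filling the copy's destination, and
/// copying over a fill overwrites it entirely, so the copy can target the
/// fill's output directly; either way one op in the chain becomes dead.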
struct FoldFillWithCopy : OpRewritePattern<linalg::CopyOp> {
  using OpRewritePattern<linalg::CopyOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::CopyOp copyOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = copyOp.getInputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(copyOp, copyOp.getResultTypes(),
                                          fillOp.getInputs(),
                                          copyOp.getOutputs());
      return success();
    }
    if (auto fillOp = copyOp.getOutputs().front().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<linalg::CopyOp>(copyOp, copyOp.getInputs(),
                                                  fillOp.getOutputs());
      return success();
    }
    return failure();
  }
};

/// Fold fill with transpose.
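/// A transpose of a filled tensor holds the same value everywhere, so it can
/// be rewritten as a fill of the transpose's init.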
struct FoldFillWithTranspose : OpRewritePattern<linalg::TransposeOp> {
  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
                                PatternRewriter &rewriter) const override {
    if (auto fillOp = transposeOp.getInput().getDefiningOp<FillOp>()) {
      rewriter.replaceOpWithNewOp<FillOp>(
          transposeOp, transposeOp.getResultTypes(), fillOp.getInputs(),
          transposeOp.getDpsInitOperand(0)->get());
      return success();
    }
    return failure();
  }
};

} // namespace

void FillOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results
      .add<FoldFillWithCopy, FoldFillWithTensorExtract, FoldFillWithPack,
           FoldFillWithPad, FoldFillWithTensorReshape<tensor::CollapseShapeOp>,
           FoldFillWithTensorReshape<tensor::ExpandShapeOp>,
           FoldInsertPadIntoFill, FoldFillWithTranspose>(context);
}

//===----------------------------------------------------------------------===//
// GenericOp
//===----------------------------------------------------------------------===//

static void buildGenericRegion(
    OpBuilder &builder, Location loc, Region &region, ValueRange inputs,
    ValueRange outputs,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild) {
  SmallVector<Type, 4> blockArgTypes;
  SmallVector<Location, 4> blockArgLocs;
  for (ValueRange container : {inputs, outputs}) {
    for (Value v : container) {
      Type t = v.getType();
      blockArgTypes.push_back(
          isa<MemRefType, RankedTensorType>(t) ? getElementTypeOrSelf(t) : t);
      blockArgLocs.push_back(v.getLoc());
    }
  }

  OpBuilder::InsertionGuard guard(builder);
  Block *bodyBlock =
      builder.createBlock(&region, region.end(), blockArgTypes, blockArgLocs);
  bodyBuild(builder, loc, bodyBlock->getArguments());
}

void GenericOp::getAsmBlockArgumentNames(Region &region,
                                         OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "out");
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayAttr indexingMaps,
    ArrayAttr iteratorTypes, StringAttr doc, StringAttr libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall);
  result.addAttributes(attributes);
  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, outputs, bodyBuild);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs,
        builder.getAffineMapArrayAttr(indexingMaps),
        builder.getArrayAttr(llvm::to_vector(llvm::map_range(
            iteratorTypes,
            [&](utils::IteratorType iter) -> mlir::Attribute {
              return IteratorTypeAttr::get(builder.getContext(), iter);
            }))),
        doc.empty() ? StringAttr() : builder.getStringAttr(doc),
        libraryCall.empty() ? StringAttr() : builder.getStringAttr(libraryCall),
        bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes, StringRef doc,
    StringRef libraryCall,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, outputs, indexingMaps,
        iteratorTypes, doc, libraryCall, bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, inputs, outputs, indexingMaps, iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::build(
    OpBuilder &builder, OperationState &result, TypeRange resultTensorTypes,
    ValueRange inputs, ValueRange outputs, ArrayRef<AffineMap> indexingMaps,
    ArrayRef<utils::IteratorType> iteratorTypes,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, resultTensorTypes, inputs, outputs, indexingMaps,
        iteratorTypes,
        /*doc=*/"",
        /*libraryCall=*/"", bodyBuild, attributes);
}

void GenericOp::print(OpAsmPrinter &p) {
  p << " ";

  // Print extra attributes.
  auto genericAttrNames = linalgTraitAttrNames();

  llvm::StringSet<> genericAttrNamesSet;
  genericAttrNamesSet.insert(genericAttrNames.begin(), genericAttrNames.end());
  SmallVector<NamedAttribute, 8> genericAttrs;
  for (auto attr : (*this)->getAttrs()) {
    if (attr.getName() == getIteratorTypesAttrName()) {
      auto iteratorTypes =
          llvm::cast<ArrayAttr>(attr.getValue())
              .getAsValueRange<IteratorTypeAttr, utils::IteratorType>();
      // Convert IteratorType enums into the string representation. This is
      // needed, because tests still use the old format when 'iterator_types'
      // attribute is represented as an array of strings.
      // TODO: Remove this conversion once tests are fixed.
      SmallVector<Attribute> iteratorTypeNames =
          llvm::to_vector(llvm::map_range(
              iteratorTypes, [&](utils::IteratorType t) -> Attribute {
                return StringAttr::get(getContext(), stringifyIteratorType(t));
              }));

      genericAttrs.emplace_back(
          getIteratorTypesAttrName(),
          ArrayAttr::get(getContext(), iteratorTypeNames));
    } else if (genericAttrNamesSet.count(attr.getName().strref()) > 0) {
      genericAttrs.push_back(attr);
    }
  }
  if (!genericAttrs.empty()) {
    auto genericDictAttr = DictionaryAttr::get(getContext(), genericAttrs);
    p << genericDictAttr;
  }

  // Printing is shared with named ops, except for the region and attributes.
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());

  genericAttrNames.push_back("operandSegmentSizes");
  genericAttrNamesSet.insert(genericAttrNames.back());

  bool hasExtraAttrs = false;
  for (NamedAttribute n : (*this)->getAttrs()) {
    if ((hasExtraAttrs = !genericAttrNamesSet.contains(n.getName().strref())))
      break;
  }
  if (hasExtraAttrs) {
    p << " attrs = ";
    p.printOptionalAttrDict((*this)->getAttrs(),
                            /*elidedAttrs=*/genericAttrNames);
  }

  // Print region.
  if (!getRegion().empty()) {
    p << ' ';
    p.printRegion(getRegion());
  }

  // Print results.
  printNamedStructuredOpResults(p, getResultTensors().getTypes());
}

ParseResult GenericOp::parse(OpAsmParser &parser, OperationState &result) {
  DictionaryAttr dictAttr;
  // Parse the core linalg traits that must check into a dictAttr.
  // The name is unimportant as we will overwrite result.attributes.
  // The core linalg traits must contain the information necessary to pass the
  // verifier.
  llvm::SMLoc attributeLocation = parser.getCurrentLocation();
  if (parser.parseAttribute(dictAttr, "_", result.attributes))
    return failure();
  result.attributes.assign(dictAttr.getValue().begin(),
                           dictAttr.getValue().end());

  // Convert the array of strings into an array of IteratorType enums. This is
  // needed, because tests still use the old format when 'iterator_types'
  // attribute is represented as an array of strings.
  // TODO: Remove this conversion once tests are fixed.
  auto iteratorTypes = dyn_cast_or_null<ArrayAttr>(
      result.attributes.get(getIteratorTypesAttrName(result.name)));
  if (!iteratorTypes) {
    return parser.emitError(attributeLocation)
           << "expected " << getIteratorTypesAttrName(result.name)
           << " array attribute";
  }

  SmallVector<Attribute> iteratorTypeAttrs;

  for (StringRef s : iteratorTypes.getAsValueRange<StringAttr>()) {
    auto maybeIteratorType = utils::symbolizeIteratorType(s);
    if (!maybeIteratorType.has_value())
      return parser.emitError(parser.getCurrentLocation())
             << "unexpected iterator_type (" << s << ")";

    iteratorTypeAttrs.push_back(
        IteratorTypeAttr::get(parser.getContext(), maybeIteratorType.value()));
  }
  result.attributes.set(getIteratorTypesAttrName(result.name),
                        parser.getBuilder().getArrayAttr(iteratorTypeAttrs));

  // Parsing is shared with named ops, except for the region.
  SmallVector<Type, 1> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
    return failure();

  // Optional attributes may be added.
  if (succeeded(parser.parseOptionalKeyword("attrs")))
    if (failed(parser.parseEqual()) ||
        failed(parser.parseOptionalAttrDict(result.attributes)))
      return failure();

  std::unique_ptr<Region> region = std::make_unique<Region>();
  if (parser.parseRegion(*region, {}))
    return failure();
  result.addRegion(std::move(region));

  // Generic ops may specify that a subset of their outputs are tensors. Such
  // outputs are specified in the result type.
  // TODO: may need to move output parsing before region parsing.
  // Need to wait for declarative assembly resolution to decide.
  SmallVector<Type, 1> outputTensorsTypes;
  if (parseNamedStructuredOpResults(parser, outputTensorsTypes))
    return failure();
  result.addTypes(outputTensorsTypes);

  return success();
}

static void getGenericEffectsImpl(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects,
    ValueRange results, const ValueRange inputOperands,
    ValueRange outputOperands) {
  for (auto operand : inputOperands) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(MemoryEffects::Read::get(), operand,
                         SideEffects::DefaultResource::get());
  }
  for (auto operand : outputOperands) {
    if (!llvm::isa<MemRefType>(operand.getType()))
      continue;
    effects.emplace_back(MemoryEffects::Read::get(), operand,
                         SideEffects::DefaultResource::get());
    effects.emplace_back(MemoryEffects::Write::get(), operand,
                         SideEffects::DefaultResource::get());
  }
}

void GenericOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
                        getDpsInits());
}

LogicalResult GenericOp::verify() { return success(); }

namespace {

/// Remove any linalg operations (on tensors) that are just copying
/// the values from inputs to the results. Requirements are:
/// 1) All iterator types are parallel,
/// 2) The body contains just a yield operation with the yielded values being
///    the arguments corresponding to the operands.
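///
/// Example (illustrative, with #id the identity map):
///   %r = linalg.generic {indexing_maps = [#id, #id],
///                        iterator_types = ["parallel"]}
///       ins(%arg0 : tensor<4xf32>) outs(%init : tensor<4xf32>) {
///   ^bb0(%in: f32, %out: f32):
///     linalg.yield %in : f32
///   } -> tensor<4xf32>
/// is replaced by %arg0 directly (with a tensor.cast or
/// sparse_tensor.convert inserted if the types differ).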
template <typename OpTy>
struct EraseIdentityLinalgOp : public OpRewritePattern<OpTy> {
  using OpRewritePattern<OpTy>::OpRewritePattern;

  LogicalResult matchAndRewrite(OpTy linalgOp,
                                PatternRewriter &rewriter) const override {
    // Check that all indexing maps are identity.
    if (llvm::any_of(linalgOp.getIndexingMapsArray(),
                     [](AffineMap map) { return !map.isIdentity(); }))
      return failure();

    // Check that the body of the linalg operation is just a linalg.yield
    // operation.
    Block &body = linalgOp->getRegion(0).front();
    if (!llvm::hasSingleElement(body))
      return failure();
    auto yieldOp = dyn_cast<linalg::YieldOp>(body.getTerminator());
    if (!yieldOp)
      return failure();

    // In the buffer case, we need to check exact buffer equality.
    if (linalgOp.hasPureBufferSemantics()) {
      if (linalgOp.getNumDpsInputs() == 1 && linalgOp.getNumDpsInits() == 1 &&
          linalgOp.getDpsInputOperand(0)->get() ==
              linalgOp.getDpsInitOperand(0)->get()) {
        rewriter.eraseOp(linalgOp);
        return success();
      }
      return failure();
    }

    // Mixed semantics is not supported yet.
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Get the argument number of the returned values. That is the operand
    // number to use for replacing uses of this operation.
    SmallVector<Value> returnedArgs;
    for (const auto &yieldVal : llvm::enumerate(yieldOp.getValues())) {
      auto yieldArg = llvm::dyn_cast<BlockArgument>(yieldVal.value());
      if (!yieldArg || yieldArg.getOwner() != &body)
        return failure();
      unsigned argumentNumber = yieldArg.getArgNumber();
      Value returnedArg = linalgOp->getOperand(argumentNumber);
      Type resultType = linalgOp->getResult(yieldVal.index()).getType();
      // The input can have a different type than the result, e.g. a dynamic
      // input dimension can be turned into a static output dimension.
      Type returnType = returnedArg.getType();
      if (returnType != resultType) {
        // Distinguish between sparse conversion or dense tensor casting.
        // TODO: unify the two ops?
        if (sparse_tensor::getSparseTensorEncoding(returnType) ||
            sparse_tensor::getSparseTensorEncoding(resultType))
          returnedArg = rewriter.create<sparse_tensor::ConvertOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        else {
          if (!tensor::CastOp::areCastCompatible(returnedArg.getType(),
                                                 resultType))
            return failure();
          returnedArg = rewriter.create<tensor::CastOp>(
              linalgOp.getLoc(), resultType, returnedArg);
        }
      }
      returnedArgs.push_back(returnedArg);
    }

    if (returnedArgs.size() != linalgOp->getNumResults())
      return failure();
    rewriter.replaceOp(linalgOp, returnedArgs);
    return success();
  }
};

} // namespace

void GenericOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                            MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<GenericOp>>(context);
}

LogicalResult GenericOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// MapOp
//===----------------------------------------------------------------------===//

static ParseResult parseDstStyleOp(
    OpAsmParser &parser, OperationState &result,
    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
        nullptr) {
  // Parse `ins` and `outs`.
  SmallVector<Type, 4> inputTypes, outputTypes;
  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
                                   /*addOperandSegmentSizes=*/false))
    return failure();

  // Add result types.
  for (Type outputType : outputTypes) {
    if (llvm::isa<RankedTensorType>(outputType))
      result.addTypes(outputType);
  }

  // Parse required attributes.
  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
    return failure();

  // Parse optional attributes.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();
  return success();
}

void MapOp::getAsmBlockArgumentNames(Region &region,
                                     OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
}

void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "mapped");
}

void MapOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs, Value init,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, init);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, /*outputs=*/{}, bodyBuild);
}

static void addBodyWithPayloadOp(OpAsmParser &parser, OperationState &result,
                                 const OperationName &payloadOpName,
                                 const NamedAttrList &payloadOpAttrs,
                                 ArrayRef<Value> operands,
                                 bool initFirst = false) {
  OpBuilder b(parser.getContext());
  Region *body = result.addRegion();
  Block &block = body->emplaceBlock();
  b.setInsertionPointToStart(&block);
  SmallVector<Value> bbArgs;
  for (auto &operand : operands) {
    block.addArgument(
        llvm::cast<ShapedType>(operand.getType()).getElementType(),
        b.getUnknownLoc());
  }
  SmallVector<Value> payloadOpOperands;
  // If the initFirst flag is enabled, we consider init as the first position
  // of the payload operands.
  if (initFirst) {
    payloadOpOperands.push_back(block.getArguments().back());
    for (const auto &arg : block.getArguments().drop_back())
      payloadOpOperands.push_back(arg);
  } else {
    payloadOpOperands = {block.getArguments().begin(),
                         block.getArguments().end()};
  }

  Operation *payloadOp = b.create(
      result.location, b.getStringAttr(payloadOpName.getStringRef()),
      payloadOpOperands,
      TypeRange{llvm::cast<ShapedType>(result.operands.back().getType())
                    .getElementType()},
      payloadOpAttrs);
  b.create<YieldOp>(result.location, payloadOp->getResults());
}

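/// linalg.map supports a short form when the body consists of a single
/// payload op whose operands are exactly the block arguments, e.g.
/// (illustrative):
///   %mapped = linalg.map { arith.addf }
///       ins(%lhs, %rhs : tensor<8xf32>, tensor<8xf32>)
///       outs(%init : tensor<8xf32>)
/// Otherwise the mapper region is spelled out explicitly and parsed below.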
ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
  std::optional<OperationName> payloadOpName;
  NamedAttrList payloadOpAttrs;
  if (succeeded(parser.parseOptionalLBrace())) {
    FailureOr<OperationName> operationName = parser.parseCustomOperationName();
    if (failed(operationName))
      return failure();
    if (parser.parseOptionalAttrDict(payloadOpAttrs))
      return failure();
    payloadOpName = operationName.value();
    if (parser.parseRBrace())
      return failure();
  }

  if (parseDstStyleOp(parser, result))
    return failure();

  if (payloadOpName.has_value()) {
    addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
                         ArrayRef(result.operands).drop_back());
  } else {
    SmallVector<OpAsmParser::Argument> regionArgs;
    if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
                                 /*allowType=*/true, /*allowAttrs=*/true)) {
      return failure();
    }
    Region *body = result.addRegion();
    if (parser.parseRegion(*body, regionArgs))
      return failure();
  }
  return success();
}

// Retrieve the operation from the body, if it is the only one (except the
// yield) and if it takes the same number of arguments as the body does.
// If the initFirst flag is enabled, we check that init takes the first
// position in the operands of the payload.
static Operation *findPayloadOp(Block *body, bool initFirst = false) {
  if (body->getOperations().size() != 2)
    return nullptr;
  Operation &payload = body->getOperations().front();
  assert(isa<YieldOp>(body->getOperations().back()));

  if (payload.getNumOperands() == 0 ||
      payload.getNumOperands() != body->getNumArguments())
    return nullptr;
  if (initFirst) {
    // Check init.
    if (payload.getOperands().back() != body->getArgument(0))
      return nullptr;
    // Check the rest.
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments().drop_front())) {
      if (bbArg != operand)
        return nullptr;
    }
  } else {
    for (const auto &[operand, bbArg] :
         llvm::zip(payload.getOperands(), body->getArguments())) {
      if (bbArg != operand)
        return nullptr;
    }
  }
  return &payload;
}

void printShortForm(OpAsmPrinter &p, Operation *payloadOp) {
  SmallVector<StringRef> elidedAttrs;
  std::string attrToElide;
  p << " { " << payloadOp->getName().getStringRef();
  for (const auto &attr : payloadOp->getAttrs()) {
    auto fastAttr =
        llvm::dyn_cast<mlir::arith::FastMathFlagsAttr>(attr.getValue());
    if (fastAttr && fastAttr.getValue() == mlir::arith::FastMathFlags::none) {
      attrToElide = attr.getName().str();
      elidedAttrs.push_back(attrToElide);
      break;
    }
  }
  p.printOptionalAttrDict(payloadOp->getAttrs(), elidedAttrs);
  p << " }";
}

void MapOp::print(OpAsmPrinter &p) {
  Block *mapper = getBody();
  Operation *payloadOp = findPayloadOp(mapper);
  if (payloadOp) {
    printShortForm(p, payloadOp);
  }

  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  p.printOptionalAttrDict((*this)->getAttrs());

  if (!payloadOp) {
    // Print the region if the payload op was not detected.
    p.increaseIndent();
    p.printNewline();
    p << "(";
    llvm::interleaveComma(mapper->getArguments(), p,
                          [&](auto arg) { p.printRegionArgument(arg); });
    p << ") ";

    p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
    p.decreaseIndent();
  }
}

LogicalResult MapOp::verify() {
  auto *bodyBlock = getBody();
  auto blockArgs = bodyBlock->getArguments();

  // Checks if the number of `inputs` matches the arity of the `mapper` region.
  if (getInputs().size() != blockArgs.size())
    return emitOpError() << "expects number of operands to match the arity of "
                            "mapper, but got: "
                         << getInputs().size() << " and " << blockArgs.size();

  // The parameters of the mapper should all match the element types of the
  // inputs.
  for (const auto &[bbArgType, inputArg] :
       llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
    auto inputElemType =
        llvm::cast<ShapedType>(inputArg.getType()).getElementType();
    if (bbArgType != inputElemType) {
      return emitOpError() << "expected element type of input " << inputElemType
                           << " to match bbArg type " << bbArgType;
    }
  }

  // The shape of each input must match the shape of the output.
  auto outputShape = getInit().getType().getShape();
  for (Type inputArgType : TypeRange{getInputs()}) {
    auto inputElemShape = llvm::cast<ShapedType>(inputArgType).getShape();
    if (inputElemShape != outputShape) {
      return emitOpError() << "expected shape of input (" << inputElemShape
                           << ") to match shape of output (" << outputShape
                           << ")";
    }
  }

  return success();
}

SmallVector<utils::IteratorType> MapOp::getIteratorTypesArray() {
  int64_t rank = getInit().getType().getRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}

ArrayAttr MapOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  int64_t numIndexingMaps = getOperands().size();
  return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
      numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
}

void MapOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
                        getDpsInits());
}

//===----------------------------------------------------------------------===//
// ReduceOp
//===----------------------------------------------------------------------===//

void ReduceOp::getAsmBlockArgumentNames(Region &region,
                                        OpAsmSetValueNameFn setNameFn) {
  for (Value v : getRegionInputArgs())
    setNameFn(v, "in");
  for (Value v : getRegionOutputArgs())
    setNameFn(v, "init");
}

void ReduceOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "reduced");
}

void ReduceOp::build(
    OpBuilder &builder, OperationState &result, ValueRange inputs,
    ValueRange inits, ArrayRef<int64_t> dimensions,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuild,
    ArrayRef<NamedAttribute> attributes) {
  build(builder, result, TypeRange{}, inputs, inits, dimensions);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  for (Value init : inits) {
    Type initType = init.getType();
    if (llvm::isa<RankedTensorType>(initType))
      result.addTypes(initType);
  }

  if (bodyBuild)
    buildGenericRegion(builder, result.location, *result.regions.front(),
                       inputs, inits, bodyBuild);
}

SmallVector<utils::IteratorType> ReduceOp::getIteratorTypesArray() {
  int64_t inputRank =
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                                 utils::IteratorType::parallel);
  for (int64_t reductionDim : getDimensions())
    iteratorTypes[reductionDim] = utils::IteratorType::reduction;
  return iteratorTypes;
}

ArrayAttr ReduceOp::getIndexingMaps() {
  int64_t inputRank =
      llvm::cast<ShapedType>(getInputs()[0].getType()).getRank();
  SmallVector<AffineMap> affineMaps(
      getNumDpsInputs(),
      AffineMap::getMultiDimIdentityMap(inputRank, getContext()));
  AffineMap resultMap =
      AffineMap::getMultiDimIdentityMap(inputRank, getContext())
          .dropResults(getDimensions());
  for (int64_t i = 0, e = getNumDpsInits(); i < e; ++i)
    affineMaps.push_back(resultMap);
  return Builder(getContext()).getAffineMapArrayAttr(affineMaps);
}

void ReduceOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
                        getDpsInits());
}

1520static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
1521 NamedAttrList &attributes,
1522 StringRef attributeName) {
1523 if (parser.parseKeyword(keyword: attributeName) || parser.parseEqual())
1524 return failure();
1525
1526 attributes.set(attributeName, DenseI64ArrayAttr::parse(parser, Type{}));
1527 return success();
1528}
1529
1530ParseResult ReduceOp::parse(OpAsmParser &parser, OperationState &result) {
1531 std::optional<OperationName> payloadOpName;
1532 NamedAttrList payloadOpAttrs;
1533 if (succeeded(parser.parseOptionalLBrace())) {
1534 FailureOr<OperationName> operationName = parser.parseCustomOperationName();
1535 if (failed(operationName))
1536 return failure();
1537 if (parser.parseOptionalAttrDict(payloadOpAttrs))
1538 return failure();
1539 payloadOpName = operationName.value();
1540 if (parser.parseRBrace())
1541 return failure();
1542 }
1543
1544 if (parseDstStyleOp(
1545 parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
1546 return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
1547 }))
1548 return failure();
1549
1550 if (payloadOpName.has_value()) {
1551 addBodyWithPayloadOp(parser, result, payloadOpName.value(), payloadOpAttrs,
1552 ArrayRef(result.operands), /*initFirst=*/true);
1553 } else {
1554 SmallVector<OpAsmParser::Argument> regionArgs;
1555 if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
1556 /*allowType=*/true, /*allowAttrs=*/true)) {
1557 return failure();
1558 }
1559
1560 Region *body = result.addRegion();
1561 if (parser.parseRegion(*body, regionArgs))
1562 return failure();
1563 }
1564
1565 return success();
1566}
1567
1568static void printDenseI64ArrayAttr(OpAsmPrinter &p, StringRef attributeName,
1569 ArrayRef<int64_t> attributeValue) {
1570 p << ' ' << attributeName << " = [" << attributeValue << "] ";
1571}
1572
1573void ReduceOp::print(OpAsmPrinter &p) {
1574 Block *mapper = getBody();
1575 Operation *payloadOp = findPayloadOp(mapper, /*initFirst=*/true);
1576 if (payloadOp) {
1577 printShortForm(p, payloadOp);
1578 }
1579
1580 printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
1581 printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
1582 p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
1583 if (!payloadOp) {
1584 // Print region if the payload op was not detected.
1585 p.increaseIndent();
1586 p.printNewline();
1587 p << "(";
1588 llvm::interleaveComma(mapper->getArguments(), p,
1589 [&](auto arg) { p.printRegionArgument(arg); });
1590 p << ") ";
1591
1592 p.printRegion(getCombiner(), /*printEntryBlockArgs=*/false);
1593 p.decreaseIndent();
1594 }
1595}
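// For illustration only (invented names and shapes): a sum reduction over
// dimension 1 round-trips through the short form above as
//
//   %reduced = linalg.reduce { arith.addf }
//       ins(%input : tensor<16x32xf32>)
//       outs(%init : tensor<16xf32>)
//       dimensions = [1]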

LogicalResult ReduceOp::verify() {
  ArrayRef<int64_t> dimensionsRef = getDimensions();

  for (int64_t i = 1; i < getNumDpsInputs(); ++i) {
    if (llvm::cast<ShapedType>(getInputs()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInputs()[0].getType()).getShape()) {
      return emitOpError() << "expects all inputs to have the same shapes. "
                              "Shape at input-index "
                           << i
                           << " is not equal to the shape at input-index 0.";
    }
  }
  for (int64_t i = 1; i < getNumDpsInits(); ++i) {
    if (llvm::cast<ShapedType>(getInits()[i].getType()).getShape() !=
        llvm::cast<ShapedType>(getInits()[0].getType()).getShape()) {
      return emitOpError() << "expects all outputs to have the same shapes. "
                              "Shape at output-index "
                           << i
                           << " is not equal to the shape at output-index 0.";
    }
  }
  auto inputType = llvm::cast<ShapedType>(getInputs()[0].getType());
  auto initType = llvm::cast<ShapedType>(getInits()[0].getType());

  DenseSet<int64_t> dimensionsToReduce;
  for (int64_t dimension : dimensionsRef) {
    if (dimension < 0 || dimension >= inputType.getRank()) {
      return emitOpError()
             << "dimensions for reduction should be in the range [0, "
             << inputType.getRank() - 1 << "].";
    }
    dimensionsToReduce.insert(dimension);
  }

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

  // Input dimensions that will be left after the reduction.
  SmallVector<int64_t> reducedInputDims;
  for (const auto &en : llvm::enumerate(inputDims)) {
    if (!dimensionsToReduce.count(en.index()))
      reducedInputDims.push_back(en.value());
  }

  if (reducedInputDims.size() != static_cast<size_t>(initType.getRank())) {
    return emitOpError() << "number of dimensions after reduction "
                         << reducedInputDims.size()
                         << " doesn't match the init rank "
                         << initType.getRank();
  }

  if (reducedInputDims != initDims)
    return emitOpError() << "init dimensions [" << initDims
                         << "] doesn't match input dimensions after reduction ["
                         << reducedInputDims << "]";

  Block *block = getBody();
  if (block->getNumArguments() != this->getNumOperands())
    return emitOpError()
           << "mismatching number of operands and block arguments";

  // Check that the first block arguments match the element type of the inputs.
  for (auto [input, bbArg] : llvm::zip(getInputs(), block->getArguments())) {
    Type inputElementType =
        llvm::cast<ShapedType>(input.getType()).getElementType();
    if (inputElementType != bbArg.getType())
      return emitOpError()
             << "input element type " << inputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }

  // Check that the last block arguments match the element type of the outputs.
  for (auto [output, bbArg] : llvm::zip(
           getDpsInits(), block->getArguments().take_back(getNumDpsInits()))) {
    auto outputElementType =
        llvm::cast<ShapedType>(output.getType()).getElementType();
    if (outputElementType != bbArg.getType())
      return emitOpError()
             << "output element type " << outputElementType
             << " does not match corresponding block argument type "
             << bbArg.getType();
  }
  return success();
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//

static void buildIdentityRegion(OpBuilder &builder, Location loc,
                                Region &region, ValueRange inputs,
                                ValueRange outputs) {
  buildGenericRegion(builder, loc, region, inputs, outputs,
                     [](OpBuilder &b, Location loc, ValueRange args) {
                       b.create<linalg::YieldOp>(loc, args[0]);
                     });
}

void TransposeOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        DenseI64ArrayAttr permutation,
                        ArrayRef<NamedAttribute> attributes) {
  result.addOperands(input);
  result.addOperands(init);
  result.addAttribute(getPermutationAttrName(result.name), permutation);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
                      init);
}

void TransposeOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        ArrayRef<int64_t> permutation,
                        ArrayRef<NamedAttribute> attributes) {
  build(builder, result, input, init, builder.getDenseI64ArrayAttr(permutation),
        attributes);
}

ParseResult TransposeOp::parse(OpAsmParser &parser, OperationState &result) {
  if (failed(parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "permutation");
          })))
    return failure();

  OpBuilder builder(parser.getContext());
  buildIdentityRegion(builder, result.location, *result.addRegion(),
                      /*inputs=*/result.operands,
                      /*outputs=*/{});
  return success();
}

void TransposeOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "transposed");
}

void TransposeOp::print(OpAsmPrinter &p) {
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  printDenseI64ArrayAttr(p, getPermutationAttrName(), getPermutation());
  p.printOptionalAttrDict((*this)->getAttrs(), {getPermutationAttrName()});
}
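// For illustration only (invented names and shapes): with the printer above,
// a 2-d transpose looks like
//
//   %transposed = linalg.transpose
//       ins(%input : tensor<16x64xf32>)
//       outs(%init : tensor<64x16xf32>)
//       permutation = [1, 0]
//
// Dim i of the init equals dim permutation[i] of the input, which is exactly
// what the verifier below enforces.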

LogicalResult TransposeOp::verify() {
  ArrayRef<int64_t> permutationRef = getPermutation();

  if (!isPermutationVector(permutationRef))
    return emitOpError("permutation is not valid");

  auto inputType = getInput().getType();
  auto initType = getInit().getType();

  int64_t rank = inputType.getRank();

  if (rank != initType.getRank())
    return emitOpError() << "input rank " << rank
                         << " does not match init rank " << initType.getRank();

  if (rank != static_cast<int64_t>(permutationRef.size()))
    return emitOpError() << "size of permutation " << permutationRef.size()
                         << " does not match the argument rank " << rank;

  auto inputDims = inputType.getShape();
  auto initDims = initType.getShape();

  for (int64_t i = 0; i < rank; ++i) {
    int64_t inputDim = inputDims[permutationRef[i]];
    int64_t initDim = initDims[i];

    if (inputDim != initDim) {
      return emitOpError() << "dim(result, " << i << ") = " << initDim
                           << " doesn't match dim(input, permutation[" << i
                           << "]) = " << inputDim;
    }
  }

  return success();
}

SmallVector<utils::IteratorType> TransposeOp::getIteratorTypesArray() {
  int64_t rank = getInit().getType().getRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}

ArrayAttr TransposeOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  return builder.getAffineMapArrayAttr(
      {inversePermutation(AffineMap::getPermutationMap(
           llvm::to_vector_of<unsigned>(getPermutation()), getContext())),
       builder.getMultiDimIdentityMap(rank)});
}

void TransposeOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
                        getDpsInits());
}

LogicalResult TransposeOp::fold(FoldAdaptor adaptor,
                                SmallVectorImpl<OpFoldResult> &result) {
  // Transpose of a rank-0 input (empty permutation) is a no-op.
  if (getPermutation().size() == 0) {
    result.push_back(getInput());
    return success();
  }
  // Identity permutation.
  if (isIdentityPermutation(getPermutation())) {
    result.push_back(getInput());
    return success();
  }

  return failure();
}
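// For illustration only: both fold cases collapse the op into its input, e.g.
//
//   %t = linalg.transpose ins(%x : tensor<4x8xf32>)
//        outs(%init : tensor<4x8xf32>) permutation = [0, 1]
//
// folds to %x, since [0, 1] is the identity permutation.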

//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//

void BroadcastOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        DenseI64ArrayAttr dimensions,
                        ArrayRef<NamedAttribute> attributes) {
  result.addOperands(input);
  result.addOperands(init);
  result.addAttribute(getDimensionsAttrName(result.name), dimensions);
  result.addAttributes(attributes);

  // Add output types for `RankedTensorType` output arguments.
  Type initType = init.getType();
  if (llvm::isa<RankedTensorType>(initType))
    result.addTypes(initType);

  buildIdentityRegion(builder, result.location, *result.addRegion(), input,
                      init);
}

void BroadcastOp::build(::mlir::OpBuilder &builder,
                        ::mlir::OperationState &result, Value input, Value init,
                        ArrayRef<int64_t> dimensions,
                        ArrayRef<NamedAttribute> attributes) {
  build(builder, result, input, init, builder.getDenseI64ArrayAttr(dimensions),
        attributes);
}

ParseResult BroadcastOp::parse(OpAsmParser &parser, OperationState &result) {
  if (failed(parseDstStyleOp(
          parser, result, [&](OpAsmParser &parser, NamedAttrList &attributes) {
            return parseDenseI64ArrayAttr(parser, attributes, "dimensions");
          })))
    return failure();

  OpBuilder builder(parser.getContext());
  buildIdentityRegion(builder, result.location, *result.addRegion(),
                      /*inputs=*/result.operands,
                      /*outputs=*/{});
  return success();
}

void BroadcastOp::getAsmResultNames(
    function_ref<void(Value, StringRef)> setNameFn) {
  if (!getResults().empty())
    setNameFn(getResults().front(), "broadcasted");
}

void BroadcastOp::print(OpAsmPrinter &p) {
  printCommonStructuredOpParts(p, getDpsInputs(), getDpsInits());
  printDenseI64ArrayAttr(p, getDimensionsAttrName(), getDimensions());
  p.printOptionalAttrDict((*this)->getAttrs(), {getDimensionsAttrName()});
}
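// For illustration only (invented names and shapes): `dimensions` lists the
// init dimensions that the broadcast adds, so
//
//   %broadcasted = linalg.broadcast
//       ins(%input : tensor<16xf32>)
//       outs(%init : tensor<16x64xf32>)
//       dimensions = [1]
//
// maps input dim 0 to init dim 0 and materializes the new dim 1.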

LogicalResult BroadcastOp::verify() {
  ArrayRef<int64_t> dimensionsRef = getDimensions();

  auto inputType = getInput().getType();
  auto initType = getInit().getType();

  int64_t inputRank = inputType.getRank();
  int64_t initRank = initType.getRank();

  auto inputShape = inputType.getShape();
  auto initShape = initType.getShape();

  if ((size_t)inputRank + dimensionsRef.size() != (size_t)initRank)
    return emitOpError() << "input rank plus added dimensions does not "
                            "match init rank. input rank: "
                         << inputRank
                         << ", dimensions size: " << dimensionsRef.size()
                         << ", init rank: " << initRank;

  for (const auto &[idx, dim] : llvm::enumerate(dimensionsRef)) {
    if (dim < 0 || dim >= initRank)
      return emitOpError() << "dimension " << idx
                           << " is out of range. expected range: [0, "
                           << initRank - 1 << "], got: " << dim;
  }

  // Mapping from input dims to init dims.
  SmallVector<int64_t> dimMap;
  for (auto dim : llvm::seq<int64_t>(0, initRank)) {
    if (!llvm::is_contained(dimensionsRef, dim))
      dimMap.push_back(dim);
  }

  for (const auto &[inputDimIdx, initDimIdx] : llvm::enumerate(dimMap)) {
    // This dimension is mapped from the input; init and input dims should
    // match.
    if (inputShape[inputDimIdx] != initShape[initDimIdx])
      return emitOpError() << "input dim " << inputDimIdx
                           << " should match init dim " << initDimIdx
                           << ". input: " << inputShape[inputDimIdx]
                           << ", init: " << initShape[initDimIdx];
  }

  return success();
}

SmallVector<utils::IteratorType> BroadcastOp::getIteratorTypesArray() {
  int64_t rank = getInit().getType().getRank();
  return SmallVector<utils::IteratorType>(rank, utils::IteratorType::parallel);
}

ArrayAttr BroadcastOp::getIndexingMaps() {
  Builder builder(getContext());
  int64_t rank = getInit().getType().getRank();
  return builder.getAffineMapArrayAttr(
      {builder.getMultiDimIdentityMap(rank).dropResults(getDimensions()),
       builder.getMultiDimIdentityMap(rank)});
}

void BroadcastOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
                        getDpsInits());
}

void BroadcastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                              MLIRContext *context) {
  results.add<EraseIdentityLinalgOp<BroadcastOp>>(context);
}

//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//

void linalg::YieldOp::print(OpAsmPrinter &p) {
  if (getNumOperands() > 0)
    p << ' ' << getOperands();
  p.printOptionalAttrDict((*this)->getAttrs());
  if (getNumOperands() > 0)
    p << " : " << getOperandTypes();
}

ParseResult YieldOp::parse(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::UnresolvedOperand, 2> opInfo;
  SmallVector<Type, 2> types;
  SMLoc loc = parser.getCurrentLocation();
  return failure(parser.parseOperandList(opInfo) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 (!opInfo.empty() && parser.parseColonTypeList(types)) ||
                 parser.resolveOperands(opInfo, types, loc, result.operands));
}

// Check that the number and types of the yield operands match the element
// types of the LinalgOp interface's shaped operands.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp) {
  if (op.getNumOperands() != linalgOp.getNumDpsInits())
    return op.emitOpError("expected number of yield values (")
           << op.getNumOperands()
           << ") to match the number of inits / outs operands of the enclosing "
           << "LinalgOp (" << linalgOp.getNumDpsInits() << ")";

  for (OpOperand &opOperand : op->getOpOperands()) {
    OpOperand *outputOperand =
        linalgOp.getDpsInitOperand(opOperand.getOperandNumber());
    Type elementType = outputOperand->get().getType();
    if (isa<MemRefType, RankedTensorType>(elementType))
      elementType = getElementTypeOrSelf(outputOperand->get().getType());
    if (opOperand.get().getType() != elementType)
      return op.emitOpError("type of yield operand ")
             << (opOperand.getOperandNumber() + 1) << " ("
             << opOperand.get().getType() << ") doesn't match "
             << "the element type of the enclosing linalg.generic op ("
             << elementType << ")";
  }
  return success();
}

LogicalResult linalg::YieldOp::verify() {
  auto *parentOp = (*this)->getParentOp();
  if (parentOp->getNumRegions() != 1 || parentOp->getRegion(0).empty())
    return emitOpError("expected single non-empty parent region");

  if (auto linalgOp = dyn_cast<LinalgOp>(parentOp))
    return verifyYield(*this, linalgOp);

  return emitOpError("expected parent op with LinalgOp interface");
}

//===----------------------------------------------------------------------===//
// IndexOp
//===----------------------------------------------------------------------===//

LogicalResult IndexOp::verify() {
  auto linalgOp = dyn_cast<LinalgOp>((*this)->getParentOp());
  if (!linalgOp)
    return emitOpError("expected parent op with LinalgOp interface");
  if (linalgOp.getNumLoops() <= getDim())
    return emitOpError("expected dim (")
           << getDim() << ") to be lower than the number of loops ("
           << linalgOp.getNumLoops() << ") of the enclosing LinalgOp";
  return success();
}
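// For illustration only: linalg.index is only legal inside a LinalgOp body,
// where it yields the current iteration index of the chosen loop, e.g.
// (details of the generic op elided)
//
//   linalg.generic ... {
//   ^bb0(%out: f32):
//     %i = linalg.index 0 : index  // index of the first loop
//     %j = linalg.index 1 : index  // index of the second loop
//     ...
//   }
//
// A `dim` of 2 in this two-loop generic would be rejected by the verifier
// above.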

/////// Operations corresponding to library calls defined with Tablegen ////////

#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/Linalg/IR/LinalgStructuredOps.cpp.inc"

AffineMap mlir::linalg::extractOrIdentityMap(std::optional<AffineMap> maybeMap,
                                             unsigned rank,
                                             MLIRContext *context) {
  if (maybeMap)
    return *maybeMap;
  if (rank == 0)
    return AffineMap::get(context);
  return AffineMap::getMultiDimIdentityMap(rank, context);
}

SmallVector<AffineExpr, 4>
mlir::linalg::makeAffineDimExprs(unsigned num, unsigned &startIdx,
                                 MLIRContext *context) {
  SmallVector<AffineExpr, 4> res;
  res.reserve(num);
  for (unsigned i = 0; i < num; ++i)
    res.push_back(getAffineDimExpr(startIdx++, context));
  return res;
}

SmallVector<AffineExpr, 4> mlir::linalg::concat(ArrayRef<AffineExpr> a,
                                                ArrayRef<AffineExpr> b) {
  auto rangeA = llvm::make_range(a.begin(), a.end());
  auto rangeB = llvm::make_range(b.begin(), b.end());
  auto concatRanges = llvm::concat<const AffineExpr>(rangeA, rangeB);
  return llvm::to_vector<4>(concatRanges);
}

static LogicalResult appendMangledType(llvm::raw_string_ostream &ss, Type t) {
  if (auto memref = llvm::dyn_cast<MemRefType>(t)) {
    ss << "view";
    for (auto size : memref.getShape())
      if (size < 0)
        ss << "sx";
      else
        ss << size << "x";
    if (failed(appendMangledType(ss, memref.getElementType())))
      return failure();
    if (auto as = memref.getMemorySpace()) {
      if (auto attr = llvm::dyn_cast<IntegerAttr>(as))
        ss << "as" << attr.getInt();
      else
        return failure();
    }
    return success();
  }
  if (auto vec = llvm::dyn_cast<VectorType>(t)) {
    ss << "vector";
    llvm::interleave(
        vec.getShape(), [&](int64_t i) { ss << i; }, [&]() { ss << "x"; });
    if (failed(appendMangledType(ss, vec.getElementType())))
      return failure();
    return success();
  }
  if (t.isSignlessIntOrIndexOrFloat()) {
    ss << t;
    return success();
  }
  return failure();
}

std::string mlir::linalg::generateLibraryCallName(Operation *op) {
  assert(isa<LinalgOp>(op));
  std::string name(op->getName().getStringRef().str());
  std::string fun = "";
  for (NamedAttribute kv : op->getAttrs()) {
    if (UnaryFnAttr ufa = llvm::dyn_cast<UnaryFnAttr>(kv.getValue())) {
      fun = stringifyEnum(ufa.getValue()).str() + "_";
    } else if (BinaryFnAttr bfa = llvm::dyn_cast<BinaryFnAttr>(kv.getValue())) {
      fun = stringifyEnum(bfa.getValue()).str() + "_";
    }
  }
  name.reserve(128);
  std::replace(name.begin(), name.end(), '.', '_');
  llvm::raw_string_ostream ss(name);
  ss << "_" << fun;
  for (Type t : op->getOperandTypes()) {
    if (failed(appendMangledType(ss, t)))
      return std::string();
    ss << "_";
  }
  std::string res = ss.str();
  res.pop_back();
  return res;
}
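// For illustration only: tracing the mangling above, a linalg.matmul whose
// three operands are all memref<?x?xf32> should produce roughly
//
//   linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32
//
// ('.' replaced by '_', then "view" plus "sx" per dynamic dim plus the
// element type, one such segment per operand).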

//===----------------------------------------------------------------------===//
// Canonicalizers and Folders.
//===----------------------------------------------------------------------===//

namespace {
struct EraseDeadLinalgOp : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp op,
                                PatternRewriter &rewriter) const override {
    for (OpOperand &opOperand : op->getOpOperands()) {
      // Linalg "inputs" may be either tensor or memref type.
      // tensor<0xelt_type> is a convention that may not always mean
      // "0 iterations". Only erase in cases where we see memref<...x0x...>.
      auto mt = llvm::dyn_cast<MemRefType>(opOperand.get().getType());
      if (!mt)
        continue;
      if (llvm::is_contained(op.getShape(&opOperand), 0)) {
        rewriter.eraseOp(op);
        return success();
      }
    }
    return failure();
  }
};

/// Fold LinalgOps with a `tensor.cast` consumer if the `tensor.cast` has a
/// result that is more static than the linalg op.
struct FoldTensorCastConsumerOp : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp castOp,
                                PatternRewriter &rewriter) const override {
    if (!tensor::canFoldIntoProducerOp(castOp))
      return failure();

    auto linalgOp = castOp.getSource().getDefiningOp<LinalgOp>();
    if (!linalgOp)
      return failure();

    // The cast may be in a conditionally reachable region, in which case
    // folding will generate invalid code. Only conservatively fold ops in the
    // same block for now.
    if (castOp->getBlock() != linalgOp->getBlock())
      return failure();

    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPoint(linalgOp);

    Location loc = linalgOp.getLoc();
    OpResult resultValue = llvm::cast<OpResult>(castOp.getSource());
    unsigned resultNumber = resultValue.getResultNumber();
    auto resultType =
        llvm::cast<RankedTensorType>(castOp->getResult(0).getType());
    // Replace the `outs` for the result with a `tensor.cast`. This cast is now
    // going from a more dynamic shape to a less dynamic shape. If the producer
    // for this cast, i.e. producer of the out operand, is also an operation
    // that folds with tensor.cast consumer (like this pattern), the cast will
    // continue to propagate as far up the stack as it can go.
    OpOperand *outOperand = linalgOp.getDpsInitOperand(resultNumber);
    Value newOperand =
        rewriter.create<tensor::CastOp>(loc, resultType, outOperand->get());
    SmallVector<Value> newOperands = linalgOp.getDpsInputs();
    SmallVector<Value> outputOperands(linalgOp.getDpsInits().begin(),
                                      linalgOp.getDpsInits().end());
    outputOperands[resultNumber] = newOperand;
    newOperands.append(outputOperands.begin(), outputOperands.end());

    SmallVector<Type> resultTypes(linalgOp->result_type_begin(),
                                  linalgOp->result_type_end());
    resultTypes[resultNumber] = resultType;
    Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);

    // Create a tensor.cast operation back to the original type.
    Value castBack = rewriter.create<tensor::CastOp>(
        loc, resultValue.getType(), newOp->getResult(resultNumber));

    SmallVector<Value> results(newOp->result_begin(), newOp->result_end());
    results[resultNumber] = castBack;
    rewriter.replaceOp(linalgOp, results);
    rewriter.replaceOp(castOp, newOp->getResult(resultNumber));
    return success();
  }
};
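// For illustration only: under this pattern, IR of the shape
//
//   %0 = linalg.generic ... outs(%init : tensor<?x?xf32>) -> tensor<?x?xf32>
//   %1 = tensor.cast %0 : tensor<?x?xf32> to tensor<4x8xf32>
//
// is rewritten so that the generic op runs on a statically shaped init (the
// cast moves onto the `outs` operand) and %1 uses the new static result
// directly, while a cast back to tensor<?x?xf32> is kept for any other users
// of %0.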

/// For each operand in `operands`, this function maps the static sizes of
/// dimensions to their affine dim expressions.
static void populateMap(LinalgOp linalgOp, MutableArrayRef<OpOperand> operands,
                        llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize) {
  for (OpOperand &opOperand : operands) {
    if (linalgOp.isScalar(&opOperand))
      continue;
    Value src = opOperand.get();
    auto sourceType = llvm::cast<RankedTensorType>(src.getType());
    auto sourceMap = linalgOp.getMatchingIndexingMap(&opOperand);

    // Get the `sourceShape` of the `sourceType`. If the operand is a result
    // of a `tensor.cast` operation and the source of the cast operation has a
    // static shape, then assign it to the `sourceShape`.
    auto *parentOp = src.getDefiningOp();
    ArrayRef<int64_t> sourceShape = sourceType.getShape();
    if (parentOp) {
      if (auto castOp = dyn_cast<tensor::CastOp>(parentOp)) {
        Value castSource = castOp.getSource();
        auto castSourceType =
            llvm::dyn_cast<RankedTensorType>(castSource.getType());
        if (castSourceType && castSourceType.hasStaticShape())
          sourceShape = castSourceType.getShape();
      }
    }

    // If the source shape's dimension has a static shape, map the affine dim
    // expression to the known static size.
    for (unsigned i = 0; i < sourceShape.size(); i++) {
      if (sourceType.isDynamicDim(i))
        continue;
      if (auto affineDimExpr = dyn_cast<AffineDimExpr>(sourceMap.getResult(i)))
        affineExprToSize.try_emplace(affineDimExpr, sourceShape[i]);
    }
  }
}

/// Creates a new operand w.r.t. `opOperand` of `linalgOp` with static sizes
/// mapped in `affineExprToSize`. New operands are created in `newOperands` and
/// their result types are stored in `resultTypes`. If `opOperand` requires no
/// change then `changeNeeded` is false and the same operand is added to the
/// `newOperands` list.
static void createNewOperandWithStaticSizes(
    Location loc, PatternRewriter &rewriter, OpOperand *opOperand,
    llvm::DenseMap<AffineExpr, int64_t> &affineExprToSize, LinalgOp linalgOp,
    SmallVector<Value> &newOperands, SmallVector<Type> &resultTypes,
    bool &changeNeeded) {
  Value src = opOperand->get();
  newOperands.push_back(src);
  if (linalgOp.isScalar(opOperand))
    return;
  auto sourceType = llvm::cast<RankedTensorType>(src.getType());
  Type resultType = sourceType;
  if (sourceType.hasStaticShape() && linalgOp.isDpsInit(opOperand)) {
    resultTypes.push_back(resultType);
    return;
  }
  ArrayRef<int64_t> sourceShape = sourceType.getShape();
  AffineMap sourceMap = linalgOp.getMatchingIndexingMap(opOperand);
  SmallVector<int64_t> newShape;
  // If the operand is updated with a new shape, `newOperandNeeded` will be
  // true.
  bool newOperandNeeded = false;
  for (unsigned i = 0; i < sourceShape.size(); i++) {
    int64_t dimShape = sourceShape[i];
    AffineExpr dimExpr = sourceMap.getResult(i);
    if (!affineExprToSize.contains(dimExpr) || !sourceType.isDynamicDim(i)) {
      newShape.push_back(dimShape);
      continue;
    }
    // The dimension has a dynamic shape and the corresponding affine dim
    // expression is present in the map. So assign the size for the given
    // affine dim expression to the dimension.
    newShape.push_back(affineExprToSize[dimExpr]);
    newOperandNeeded = true;
  }
  resultType = RankedTensorType::get(newShape, sourceType.getElementType());
  if (newOperandNeeded) {
    changeNeeded = true;
    // Get the new operand value given its size and element type by casting it.
    Value newOperand = rewriter.create<tensor::CastOp>(loc, resultType, src);
    unsigned index = opOperand->getOperandNumber();
    newOperands[index] = newOperand;
  }
  if (linalgOp.isDpsInit(opOperand))
    resultTypes.push_back(resultType);
}

/// Static shapes for the operands can be inferred if any one of the operands
/// has a static shape. This can be done by referring to the affine dim
/// expressions for the operand.
struct InferStaticShapeOfOperands : public OpInterfaceRewritePattern<LinalgOp> {
  using OpInterfaceRewritePattern<LinalgOp>::OpInterfaceRewritePattern;

  LogicalResult matchAndRewrite(LinalgOp linalgOp,
                                PatternRewriter &rewriter) const override {
    if (!linalgOp.hasPureTensorSemantics())
      return failure();

    // Maps must be projected permutations.
    if (llvm::any_of(linalgOp.getIndexingMapsArray(), [](AffineMap map) {
          return !map.isProjectedPermutation();
        }))
      return failure();

    // Maps affine dim expressions to the static size of that dimension.
    llvm::DenseMap<AffineExpr, int64_t> affineExprToSize;
    Location loc = linalgOp.getLoc();

    // For each affine dim expression, check if the size is known. If it is
    // known, add it to the map.
    populateMap(linalgOp, linalgOp->getOpOperands(), affineExprToSize);

    SmallVector<Value> newOperands;
    SmallVector<Type> resultTypes;

    // `changeNeeded` is `false` if the operands of `linalgOp` require no
    // change in their types.
    bool changeNeeded = false;
    newOperands.reserve(linalgOp->getNumOperands());
    resultTypes.reserve(linalgOp.getNumDpsInits());

    // Iterate over all the operands and update the static sizes.
    for (OpOperand &opOperand : linalgOp->getOpOperands()) {
      createNewOperandWithStaticSizes(loc, rewriter, &opOperand,
                                      affineExprToSize, linalgOp, newOperands,
                                      resultTypes, changeNeeded);
    }

    // If the generic op already has all the required static information, no
    // canonicalization is needed.
    if (!changeNeeded)
      return failure();

    // Clone op.
    Operation *newOp = clone(rewriter, linalgOp, resultTypes, newOperands);
    SmallVector<Value> replacements;
    replacements.reserve(newOp->getNumResults());
    for (auto it : llvm::zip(linalgOp->getResults(), newOp->getResults())) {
      Value newResult = std::get<1>(it);
      Value oldResult = std::get<0>(it);
      Type newType = newResult.getType();
      Type oldType = oldResult.getType();
      replacements.push_back(
          (newType != oldType)
              ? rewriter.create<tensor::CastOp>(loc, oldType, newResult)
              : newResult);
    }
    rewriter.replaceOp(linalgOp, replacements);
    return success();
  }
};
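// For illustration only: given an elementwise generic op where one input is
// tensor<4x8xf32> and the other operands are tensor<?x?xf32>, the identity
// indexing maps tie d0 to 4 and d1 to 8, so the pattern casts the dynamic
// operands to tensor<4x8xf32>, clones the op with the static types, and casts
// the result back to tensor<?x?xf32> for existing users.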

} // namespace

// All named ops canonicalizers and folders are auto-generated in the
// .cpp.inc.

//===----------------------------------------------------------------------===//
// SoftmaxOp
//===----------------------------------------------------------------------===//

LogicalResult SoftmaxOp::verify() {
  ShapedType inputType = getInputOperandType();
  ShapedType outputType = getOutputOperandType();

  ArrayRef<int64_t> inputShape = inputType.getShape();
  ArrayRef<int64_t> outputShape = outputType.getShape();
  if (failed(verifyCompatibleShape(inputShape, outputShape)))
    return emitOpError("incompatible output shape");

  int64_t inputRank = getInputOperandRank();
  int64_t dimension = getDimension();
  if ((dimension < 0) || (dimension >= inputRank))
    return emitOpError("incorrect dimension specified");

  return success();
}

SmallVector<Range> SoftmaxOp::getIterationDomain(OpBuilder &builder) {
  int64_t operandRank = getInputOperandRank();
  SmallVector<Range> loopBounds(operandRank);
  Location loc = getLoc();
  Value zero = builder.create<arith::ConstantIndexOp>(loc, 0);
  Value one = builder.create<arith::ConstantIndexOp>(loc, 1);
  Value source = getInput();
  for (auto dim : llvm::seq<int64_t>(0, operandRank)) {
    loopBounds[dim].offset = zero;
    loopBounds[dim].size = getDimValue(builder, loc, source, dim);
    loopBounds[dim].stride = one;
  }
  return loopBounds;
}

SmallVector<utils::IteratorType> SoftmaxOp::getLoopIteratorTypes() {
  SmallVector<utils::IteratorType> iteratorTypes(getInputOperandRank(),
                                                 utils::IteratorType::parallel);
  iteratorTypes[getDimension()] = utils::IteratorType::reduction;
  return iteratorTypes;
}

FailureOr<TilingResult>
SoftmaxOp::getTiledImplementation(OpBuilder &builder,
                                  ArrayRef<OpFoldResult> offsets,
                                  ArrayRef<OpFoldResult> sizes) {
  int64_t rank = getInputOperandRank();
  auto oneAttr = builder.getI64IntegerAttr(1);
  SmallVector<OpFoldResult> strides(rank, oneAttr);
  SmallVector<Value> tiledOperands;
  tiledOperands.emplace_back(
      getSlice(builder, getLoc(), getInput(), offsets, sizes, strides));
  tiledOperands.emplace_back(
      getSlice(builder, getLoc(), getOutput(), offsets, sizes, strides));

  SmallVector<Type, 4> resultTypes;
  if (hasPureTensorSemantics())
    resultTypes.push_back(tiledOperands[1].getType());
  Operation *tiledOp =
      mlir::clone(builder, getOperation(), resultTypes, tiledOperands);

  return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}

LogicalResult SoftmaxOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  if (resultNumber == 0) {
    resultOffsets.assign(offsets.begin(), offsets.end());
    resultSizes.assign(sizes.begin(), sizes.end());
    return success();
  }
  return failure();
}

// cast(dynamic) -> static.
LogicalResult SoftmaxOp::fold(FoldAdaptor, SmallVectorImpl<OpFoldResult> &) {
  return memref::foldMemRefCast(*this);
}

LogicalResult
SoftmaxOp::reifyResultShapes(OpBuilder &b,
                             ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
  SmallVector<OpFoldResult> shapes;
  Location loc = getOperation()->getLoc();
  IRRewriter rewriter(b);
  auto inputShapedType = llvm::cast<ShapedType>(getInputOperandType());
  auto outputShapedType = llvm::cast<ShapedType>(getOutputOperandType());
  for (int64_t dim : llvm::seq<int64_t>(0, getOutputOperandRank())) {
    if (!outputShapedType.isDynamicDim(dim)) {
      // Static dim: Return IntegerAttr.
      shapes.push_back(b.getIndexAttr(inputShapedType.getDimSize(dim)));
    } else {
      // Dynamic dim: Return Value.
      OpFoldResult ofr = createOrFoldDimOp(b, loc, getInput(), dim);
      shapes.push_back(getValueOrCreateConstantIndexOp(b, loc, ofr));
    }
  }
  reifiedReturnShapes.emplace_back(std::move(shapes));
  return success();
}

void SoftmaxOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  getGenericEffectsImpl(effects, getOperation()->getResults(), getDpsInputs(),
                        getDpsInits());
}

// Helper functions for softmax decomposition.
// @{

// Helper function to produce the iterator types (reduction or parallel) and
// affine maps for the iterators used in the decomposition of softmax.
// This method creates:
// If allParallel == true:
// - iterator type: {parallel, ..., parallel}
// - affine maps:
// -- identity with inputRank dimensions.
// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
//    where N == inputRank.
//
// If allParallel == false:
// - iterator type at dim(i) == parallel for i != \p dim and
//   dim(dim) == reduction.
// - affine map:
// -- identity with inputRank dimensions.
// -- (d0, ..., dN) -> (d0, ..., d_dim-1, d_dim+1, ..., dN),
//    where N == inputRank.
static std::tuple<SmallVector<utils::IteratorType>, SmallVector<AffineMap>>
computeIteratorTypesAndIndexingMaps(OpBuilder &builder, int64_t inputRank,
                                    int64_t dim, bool allParallel = false) {
  SmallVector<utils::IteratorType> iteratorTypes(inputRank,
                                                 utils::IteratorType::parallel);
  if (!allParallel)
    iteratorTypes[dim] = utils::IteratorType::reduction;
  MLIRContext *ctxt = builder.getContext();
  auto identityMap = AffineMap::getMultiDimIdentityMap(inputRank, ctxt);
  SmallVector<AffineExpr, 2> affineExprs;
  for (int i = 0; i < inputRank; i++) {
    if (i != dim)
      affineExprs.push_back(mlir::getAffineDimExpr(i, ctxt));
  }
  auto reductionMap =
      AffineMap::get(inputRank, /*symbols=*/0, affineExprs, ctxt);
  SmallVector<AffineMap> indexingMaps{identityMap, reductionMap};
  return std::make_tuple(iteratorTypes, indexingMaps);
}

// Helper function to produce a linalg.generic that computes a reduction on
// dimension \p dim with the operation type \p T.
template <typename T>
static Value reduce(OpBuilder &builder, Location loc, Value input, Value output,
                    int64_t dim) {
  auto inputType = cast<ShapedType>(input.getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputRank = inputShape.size();
  auto [iteratorTypes, indexingMaps] =
      computeIteratorTypesAndIndexingMaps(builder, inputRank, dim);
  assert(indexingMaps.size() == 2 &&
         "We should have two maps: 1 for the input, 1 for the output");
  assert(indexingMaps[0].isIdentity() && "input map should be identity");

  auto genericOp = builder.create<linalg::GenericOp>(
      loc, output.getType(), input, output, indexingMaps, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value result = b.create<T>(loc, args[0], args[1]);
        b.create<linalg::YieldOp>(loc, result);
      });
  return genericOp.getResult(0);
}

/// Produce a linalg generic that computes the second step of the softmax
/// decomposition: res = exp(input - max), where \p max is the max of \p input
/// on dimension \p dim.
static Value buildSubAndExpOp(OpBuilder &builder, Location loc, Value input,
                              Value max, Value output, int64_t dim) {
  auto inputType = cast<ShapedType>(input.getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputRank = inputShape.size();
  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
      builder, inputRank, dim, /*allParallel=*/true);
  assert(indexingMaps.size() == 2 && "We should have one map for each input");
  assert(indexingMaps[0].isIdentity() && "input map should be identity");
  // Add the affine map for the output argument.
  indexingMaps.push_back(indexingMaps[0]);
  auto genericOp = builder.create<linalg::GenericOp>(
      loc, input.getType(), ValueRange{input, max}, output, indexingMaps,
      iteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) {
        Value diff = b.create<arith::SubFOp>(loc, args[0], args[1]);
        Value result = b.create<math::ExpOp>(loc, diff);
        b.create<linalg::YieldOp>(loc, result);
      });
  return genericOp.getResult(0);
}

/// Produce a linalg generic that computes the final step of the softmax
/// decomposition.
/// \returns linalg.generic ins(\p numerator, \p denominator) outs(\p output) {
///   yield n / d
/// }
static Value buildDivOp(OpBuilder &builder, Location loc, Value numerator,
                        Value denominator, Value output, int64_t dim) {
  auto inputType = cast<ShapedType>(numerator.getType());
  ArrayRef<int64_t> inputShape = inputType.getShape();
  int64_t inputRank = inputShape.size();
  auto [iteratorTypes, indexingMaps] = computeIteratorTypesAndIndexingMaps(
      builder, inputRank, dim, /*allParallel=*/true);
  assert(indexingMaps.size() == 2 &&
         "We should have one map for each input (2)");
  assert(indexingMaps[0].isIdentity() && "Numerator map should be identity");
  // Add the affine map for the output tensor.
  indexingMaps.push_back(indexingMaps[0]);
  auto genericOp = builder.create<linalg::GenericOp>(
      loc, numerator.getType(), ValueRange{numerator, denominator}, output,
      indexingMaps, iteratorTypes,
      [&](OpBuilder &b, Location loc, ValueRange args) {
        Value result = b.create<arith::DivFOp>(loc, args[0], args[1]);
        b.create<linalg::YieldOp>(loc, result);
      });
  return genericOp.getResult(0);
}
// @} End helper functions for softmax decomposition.

/// Given an N-dimensional tensor x, this method converts
/// softmax(x) to the following sequence of operations:
///
/// 1. Compute the max of x along dimension d. This results
///    in an N-1 dimensional tensor m.
///    m = max(x, dim = d)
///
/// 2. Subtract a broadcasted m from x and exponentiate. This results in
///    an N dimensional tensor z.
///    z = exp(x - m)
///
/// 3. Compute the sum of z along dimension d. This results in
///    an N-1 dimensional tensor l.
///    l = sum(z, dim = d)
///
/// 4. Divide z and l. This gives the N-dimensional softmax.
///    softmax = z / l
///
FailureOr<SmallVector<Value>> SoftmaxOp::decomposeOperation(OpBuilder &b) {
  OpBuilder::InsertionGuard guard(b);
  b.setInsertionPoint(*this);
  Location loc = getLoc();
  Value input = getInput();
  ShapedType inputType = getInputOperandType();
  Type elementType = inputType.getElementType();
  int64_t reductionDim = getDimension();
  SmallVector<OpFoldResult> dims = tensor::getMixedSizes(b, loc, input);
  Value output = getOutput();
  dims.erase(dims.begin() + reductionDim);
  // Step 1: Compute max along dim.
  Value outputReduce = b.create<tensor::EmptyOp>(loc, dims, elementType);
  Value neutralForMaxF = arith::getIdentityValue(arith::AtomicRMWKind::maximumf,
                                                 elementType, b, loc,
                                                 /*useOnlyFiniteValue=*/true);
  Value neutralForMaxFInit =
      b.create<linalg::FillOp>(loc, Value{neutralForMaxF}, outputReduce)
          .result();
  Value max = reduce<arith::MaximumFOp>(b, loc, input, neutralForMaxFInit,
                                        reductionDim);

  // Step 2: Subtract max from input and exponentiate.
  Value numerator = buildSubAndExpOp(b, loc, input, max, output, reductionDim);

  // Step 3: Compute sum along dim.
  Value zero = arith::getIdentityValue(arith::AtomicRMWKind::addf, elementType,
                                       b, loc, /*useOnlyFiniteValue=*/true);
  Value zeroInit =
      b.create<linalg::FillOp>(loc, Value{zero}, outputReduce).result();
  Value denominator =
      reduce<arith::AddFOp>(b, loc, numerator, zeroInit, reductionDim);

  // Step 4: Compute softmax.
  Value result =
      buildDivOp(b, loc, numerator, denominator, output, reductionDim);
  return SmallVector<Value>{result};
}
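// For illustration only: for a tensor<2x16xf32> softmax over dim = 1, the
// decomposition above materializes, modulo exact SSA names and attributes,
//
//   %empty = tensor.empty() : tensor<2xf32>
//   %minf  = linalg.fill ins(<maximumf neutral>) outs(%empty)
//   %max   = <reduce maximumf over d1>(%x, %minf)     : tensor<2xf32>
//   %z     = <exp(%x - broadcast(%max))>              : tensor<2x16xf32>
//   %zero  = linalg.fill ins(0.0) outs(%empty)
//   %sum   = <reduce addf over d1>(%z, %zero)         : tensor<2xf32>
//   %softmax = <%z / broadcast(%sum)>                 : tensor<2x16xf32>
//
// where each <...> is one linalg.generic built by the helpers above.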

//===----------------------------------------------------------------------===//
// LinalgDialect
//===----------------------------------------------------------------------===//

void LinalgDialect::getCanonicalizationPatterns(
    RewritePatternSet &results) const {
  results.add<EraseDeadLinalgOp, FoldTensorCastConsumerOp,
              InferStaticShapeOfOperands>(getContext());
}

Operation *LinalgDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return arith::ConstantOp::materialize(builder, value, type, loc);
}
