1//===- AffineOps.h - MLIR Affine Operations -------------------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file defines convenience types for working with Affine operations
10// in the MLIR operation set.
11//
12//===----------------------------------------------------------------------===//
13
14#ifndef MLIR_DIALECT_AFFINE_IR_AFFINEOPS_H
15#define MLIR_DIALECT_AFFINE_IR_AFFINEOPS_H
16
17#include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.h"
18#include "mlir/Dialect/Arith/IR/Arith.h"
19#include "mlir/IR/AffineMap.h"
20#include "mlir/IR/Builders.h"
21#include "mlir/Interfaces/ControlFlowInterfaces.h"
22#include "mlir/Interfaces/LoopLikeInterface.h"
23
24namespace mlir {
25namespace affine {
26
// Forward declarations: several of these classes are referenced by
// declarations in this header before their full definitions appear (some of
// them come from the generated AffineOps.h.inc included further below).
class AffineApplyOp;
class AffineBound;
class AffineMaxOp;
class AffineMinOp;
class AffineValueMap;

/// A utility function to check if a value is defined at the top level of an
/// op with trait `AffineScope` or is a region argument for such an op. A value
/// of index type defined at the top level is always a valid symbol for all its
/// uses.
bool isTopLevelValue(Value value);

/// A utility function to check if a value is defined at the top level of
/// `region` or is an argument of `region`. A value of index type defined at the
/// top level of an `AffineScope` region is always a valid symbol for all
/// uses in that region.
bool isTopLevelValue(Value value, Region *region);

/// Returns the closest region enclosing `op` that is held by an operation with
/// trait `AffineScope`; `nullptr` if there is no such region.
Region *getAffineScope(Operation *op);
48
49/// AffineDmaStartOp starts a non-blocking DMA operation that transfers data
50/// from a source memref to a destination memref. The source and destination
51/// memref need not be of the same dimensionality, but need to have the same
52/// elemental type. The operands include the source and destination memref's
53/// each followed by its indices, size of the data transfer in terms of the
54/// number of elements (of the elemental type of the memref), a tag memref with
55/// its indices, and optionally at the end, a stride and a
56/// number_of_elements_per_stride arguments. The tag location is used by an
57/// AffineDmaWaitOp to check for completion. The indices of the source memref,
58/// destination memref, and the tag memref have the same restrictions as any
59/// affine.load/store. In particular, index for each memref dimension must be an
60/// affine expression of loop induction variables and symbols.
61/// The optional stride arguments should be of 'index' type, and specify a
62/// stride for the slower memory space (memory space with a lower memory space
63/// id), transferring chunks of number_of_elements_per_stride every stride until
64/// %num_elements are transferred. Either both or no stride arguments should be
65/// specified. The value of 'num_elements' must be a multiple of
66/// 'number_of_elements_per_stride'. If the source and destination locations
67/// overlap the behavior of this operation is not defined.
68//
69// For example, an AffineDmaStartOp operation that transfers 256 elements of a
70// memref '%src' in memory space 0 at indices [%i + 3, %j] to memref '%dst' in
71// memory space 1 at indices [%k + 7, %l], would be specified as follows:
72//
73// %num_elements = arith.constant 256
74// %idx = arith.constant 0 : index
75// %tag = memref.alloc() : memref<1xi32, 4>
76// affine.dma_start %src[%i + 3, %j], %dst[%k + 7, %l], %tag[%idx],
77// %num_elements :
78// memref<40x128xf32, 0>, memref<2x1024xf32, 1>, memref<1xi32, 2>
79//
80// If %stride and %num_elt_per_stride are specified, the DMA is expected to
81// transfer %num_elt_per_stride elements every %stride elements apart from
82// memory space 0 until %num_elements are transferred.
83//
84// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%idx], %num_elements,
85// %stride, %num_elt_per_stride : ...
86//
87// TODO: add additional operands to allow source and destination striding, and
88// multiple stride levels (possibly using AffineMaps to specify multiple levels
89// of striding).
90class AffineDmaStartOp
91 : public Op<AffineDmaStartOp, OpTrait::MemRefsNormalizable,
92 OpTrait::VariadicOperands, OpTrait::ZeroResults,
93 OpTrait::OpInvariants, AffineMapAccessInterface::Trait,
94 MemoryEffectOpInterface::Trait> {
95public:
96 using Op::Op;
97 static ArrayRef<StringRef> getAttributeNames() { return {}; }
98
99 static void build(OpBuilder &builder, OperationState &result, Value srcMemRef,
100 AffineMap srcMap, ValueRange srcIndices, Value destMemRef,
101 AffineMap dstMap, ValueRange destIndices, Value tagMemRef,
102 AffineMap tagMap, ValueRange tagIndices, Value numElements,
103 Value stride = nullptr, Value elementsPerStride = nullptr);
104
105 /// Returns the operand index of the source memref.
106 unsigned getSrcMemRefOperandIndex() { return 0; }
107
108 /// Returns the source MemRefType for this DMA operation.
109 Value getSrcMemRef() { return getOperand(getSrcMemRefOperandIndex()); }
110 MemRefType getSrcMemRefType() {
111 return cast<MemRefType>(getSrcMemRef().getType());
112 }
113
114 /// Returns the rank (number of indices) of the source MemRefType.
115 unsigned getSrcMemRefRank() { return getSrcMemRefType().getRank(); }
116
117 /// Returns the affine map used to access the source memref.
118 AffineMap getSrcMap() { return getSrcMapAttr().getValue(); }
119 AffineMapAttr getSrcMapAttr() {
120 return cast<AffineMapAttr>(*(*this)->getInherentAttr(getSrcMapAttrStrName()));
121 }
122
123 /// Returns the source memref affine map indices for this DMA operation.
124 operand_range getSrcIndices() {
125 return {operand_begin() + getSrcMemRefOperandIndex() + 1,
126 operand_begin() + getSrcMemRefOperandIndex() + 1 +
127 getSrcMap().getNumInputs()};
128 }
129
130 /// Returns the memory space of the source memref.
131 unsigned getSrcMemorySpace() {
132 return cast<MemRefType>(getSrcMemRef().getType()).getMemorySpaceAsInt();
133 }
134
135 /// Returns the operand index of the destination memref.
136 unsigned getDstMemRefOperandIndex() {
137 return getSrcMemRefOperandIndex() + 1 + getSrcMap().getNumInputs();
138 }
139
140 /// Returns the destination MemRefType for this DMA operation.
141 Value getDstMemRef() { return getOperand(getDstMemRefOperandIndex()); }
142 MemRefType getDstMemRefType() {
143 return cast<MemRefType>(getDstMemRef().getType());
144 }
145
146 /// Returns the rank (number of indices) of the destination MemRefType.
147 unsigned getDstMemRefRank() {
148 return cast<MemRefType>(getDstMemRef().getType()).getRank();
149 }
150
151 /// Returns the memory space of the source memref.
152 unsigned getDstMemorySpace() {
153 return cast<MemRefType>(getDstMemRef().getType()).getMemorySpaceAsInt();
154 }
155
156 /// Returns the affine map used to access the destination memref.
157 AffineMap getDstMap() { return getDstMapAttr().getValue(); }
158 AffineMapAttr getDstMapAttr() {
159 return cast<AffineMapAttr>(*(*this)->getInherentAttr(getDstMapAttrStrName()));
160 }
161
162 /// Returns the destination memref indices for this DMA operation.
163 operand_range getDstIndices() {
164 return {operand_begin() + getDstMemRefOperandIndex() + 1,
165 operand_begin() + getDstMemRefOperandIndex() + 1 +
166 getDstMap().getNumInputs()};
167 }
168
169 /// Returns the operand index of the tag memref.
170 unsigned getTagMemRefOperandIndex() {
171 return getDstMemRefOperandIndex() + 1 + getDstMap().getNumInputs();
172 }
173
174 /// Returns the Tag MemRef for this DMA operation.
175 Value getTagMemRef() { return getOperand(getTagMemRefOperandIndex()); }
176 MemRefType getTagMemRefType() {
177 return cast<MemRefType>(getTagMemRef().getType());
178 }
179
180 /// Returns the rank (number of indices) of the tag MemRefType.
181 unsigned getTagMemRefRank() {
182 return cast<MemRefType>(getTagMemRef().getType()).getRank();
183 }
184
185 /// Returns the affine map used to access the tag memref.
186 AffineMap getTagMap() { return getTagMapAttr().getValue(); }
187 AffineMapAttr getTagMapAttr() {
188 return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
189 }
190
191 /// Returns the tag memref indices for this DMA operation.
192 operand_range getTagIndices() {
193 return {operand_begin() + getTagMemRefOperandIndex() + 1,
194 operand_begin() + getTagMemRefOperandIndex() + 1 +
195 getTagMap().getNumInputs()};
196 }
197
198 /// Returns the number of elements being transferred by this DMA operation.
199 Value getNumElements() {
200 return getOperand(getTagMemRefOperandIndex() + 1 +
201 getTagMap().getNumInputs());
202 }
203
204 /// Impelements the AffineMapAccessInterface.
205 /// Returns the AffineMapAttr associated with 'memref'.
206 NamedAttribute getAffineMapAttrForMemRef(Value memref) {
207 if (memref == getSrcMemRef())
208 return {StringAttr::get(getContext(), getSrcMapAttrStrName()),
209 getSrcMapAttr()};
210 if (memref == getDstMemRef())
211 return {StringAttr::get(getContext(), getDstMapAttrStrName()),
212 getDstMapAttr()};
213 assert(memref == getTagMemRef() &&
214 "DmaStartOp expected source, destination or tag memref");
215 return {StringAttr::get(getContext(), getTagMapAttrStrName()),
216 getTagMapAttr()};
217 }
218
219 /// Returns true if this is a DMA from a faster memory space to a slower one.
220 bool isDestMemorySpaceFaster() {
221 return (getSrcMemorySpace() < getDstMemorySpace());
222 }
223
224 /// Returns true if this is a DMA from a slower memory space to a faster one.
225 bool isSrcMemorySpaceFaster() {
226 // Assumes that a lower number is for a slower memory space.
227 return (getDstMemorySpace() < getSrcMemorySpace());
228 }
229
230 /// Given a DMA start operation, returns the operand position of either the
231 /// source or destination memref depending on the one that is at the higher
232 /// level of the memory hierarchy. Asserts failure if neither is true.
233 unsigned getFasterMemPos() {
234 assert(isSrcMemorySpaceFaster() || isDestMemorySpaceFaster());
235 return isSrcMemorySpaceFaster() ? 0 : getDstMemRefOperandIndex();
236 }
237
238 void
239 getEffects(SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
240 &effects);
241
242 static StringRef getSrcMapAttrStrName() { return "src_map"; }
243 static StringRef getDstMapAttrStrName() { return "dst_map"; }
244 static StringRef getTagMapAttrStrName() { return "tag_map"; }
245
246 static StringRef getOperationName() { return "affine.dma_start"; }
247 static ParseResult parse(OpAsmParser &parser, OperationState &result);
248 void print(OpAsmPrinter &p);
249 LogicalResult verifyInvariantsImpl();
250 LogicalResult verifyInvariants() { return verifyInvariantsImpl(); }
251 LogicalResult fold(ArrayRef<Attribute> cstOperands,
252 SmallVectorImpl<OpFoldResult> &results);
253
254 /// Returns true if this DMA operation is strided, returns false otherwise.
255 bool isStrided() {
256 return getNumOperands() !=
257 getTagMemRefOperandIndex() + 1 + getTagMap().getNumInputs() + 1;
258 }
259
260 /// Returns the stride value for this DMA operation.
261 Value getStride() {
262 if (!isStrided())
263 return nullptr;
264 return getOperand(getNumOperands() - 1 - 1);
265 }
266
267 /// Returns the number of elements to transfer per stride for this DMA op.
268 Value getNumElementsPerStride() {
269 if (!isStrided())
270 return nullptr;
271 return getOperand(getNumOperands() - 1);
272 }
273};
274
275/// AffineDmaWaitOp blocks until the completion of a DMA operation associated
276/// with the tag element '%tag[%index]'. %tag is a memref, and %index has to be
277/// an index with the same restrictions as any load/store index. In particular,
278/// index for each memref dimension must be an affine expression of loop
279/// induction variables and symbols. %num_elements is the number of elements
280/// associated with the DMA operation. For example:
281//
282// affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %num_elements :
283// memref<2048xf32, 0>, memref<256xf32, 1>, memref<1xi32, 2>
284// ...
285// ...
286// affine.dma_wait %tag[%index], %num_elements : memref<1xi32, 2>
287//
288class AffineDmaWaitOp
289 : public Op<AffineDmaWaitOp, OpTrait::MemRefsNormalizable,
290 OpTrait::VariadicOperands, OpTrait::ZeroResults,
291 OpTrait::OpInvariants, AffineMapAccessInterface::Trait> {
292public:
293 using Op::Op;
294 static ArrayRef<StringRef> getAttributeNames() { return {}; }
295
296 static void build(OpBuilder &builder, OperationState &result, Value tagMemRef,
297 AffineMap tagMap, ValueRange tagIndices, Value numElements);
298
299 static StringRef getOperationName() { return "affine.dma_wait"; }
300
301 /// Returns the Tag MemRef associated with the DMA operation being waited on.
302 Value getTagMemRef() { return getOperand(0); }
303 MemRefType getTagMemRefType() {
304 return cast<MemRefType>(getTagMemRef().getType());
305 }
306
307 /// Returns the affine map used to access the tag memref.
308 AffineMap getTagMap() { return getTagMapAttr().getValue(); }
309 AffineMapAttr getTagMapAttr() {
310 return cast<AffineMapAttr>(*(*this)->getInherentAttr(getTagMapAttrStrName()));
311 }
312
313 /// Returns the tag memref index for this DMA operation.
314 operand_range getTagIndices() {
315 return {operand_begin() + 1,
316 operand_begin() + 1 + getTagMap().getNumInputs()};
317 }
318
319 /// Returns the rank (number of indices) of the tag memref.
320 unsigned getTagMemRefRank() {
321 return cast<MemRefType>(getTagMemRef().getType()).getRank();
322 }
323
324 /// Impelements the AffineMapAccessInterface. Returns the AffineMapAttr
325 /// associated with 'memref'.
326 NamedAttribute getAffineMapAttrForMemRef(Value memref) {
327 assert(memref == getTagMemRef());
328 return {StringAttr::get(getContext(), getTagMapAttrStrName()),
329 getTagMapAttr()};
330 }
331
332 /// Returns the number of elements transferred by the associated DMA op.
333 Value getNumElements() { return getOperand(1 + getTagMap().getNumInputs()); }
334
335 static StringRef getTagMapAttrStrName() { return "tag_map"; }
336 static ParseResult parse(OpAsmParser &parser, OperationState &result);
337 void print(OpAsmPrinter &p);
338 LogicalResult verifyInvariantsImpl();
339 LogicalResult verifyInvariants() { return verifyInvariantsImpl(); }
340 LogicalResult fold(ArrayRef<Attribute> cstOperands,
341 SmallVectorImpl<OpFoldResult> &results);
342 void
343 getEffects(SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
344 &effects);
345};
346
/// Returns true if the given Value can be used as a dimension id in the region
/// of the closest surrounding op that has the trait `AffineScope`.
bool isValidDim(Value value);

/// Returns true if the given Value can be used as a dimension id in `region`,
/// i.e., for all its uses in `region`.
bool isValidDim(Value value, Region *region);

/// Returns true if the given value can be used as a symbol in the region of the
/// closest surrounding op that has the trait `AffineScope`.
bool isValidSymbol(Value value);

/// Returns true if the given Value can be used as a symbol for `region`, i.e.,
/// for all its uses in `region`.
bool isValidSymbol(Value value, Region *region);

/// Parses a dimension and symbol list. `numDims` is set to the number of
/// dimensions in the parsed list.
ParseResult parseDimAndSymbolList(OpAsmParser &parser,
                                  SmallVectorImpl<Value> &operands,
                                  unsigned &numDims);

/// Modifies both `map` and `operands` in-place so as to:
/// 1. drop duplicate operands
/// 2. drop unused dims and symbols from map
/// 3. promote valid symbols to symbolic operands in case they appeared as
///    dimensional operands
/// 4. propagate constant operands and drop them
void canonicalizeMapAndOperands(AffineMap *map,
                                SmallVectorImpl<Value> *operands);

/// Canonicalizes an integer set the same way canonicalizeMapAndOperands does
/// for affine maps.
void canonicalizeSetAndOperands(IntegerSet *set,
                                SmallVectorImpl<Value> *operands);

/// Returns a composed AffineApplyOp by composing `map` and `operands` with
/// other AffineApplyOps supplying those operands. The operands of the resulting
/// AffineApplyOp do not change the length of AffineApplyOp chains.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
                                      ArrayRef<OpFoldResult> operands);
/// Variant of `makeComposedAffineApply` that takes a single affine expression
/// `e` instead of an affine map.
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineExpr e,
                                      ArrayRef<OpFoldResult> operands);

/// Constructs an AffineApplyOp that applies `map` to `operands` after composing
/// the map with the maps of any other AffineApplyOp supplying the operands,
/// then immediately attempts to fold it. If folding results in a constant
/// value, no ops are actually created. The `map` must be a single-result affine
/// map.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
                                           AffineMap map,
                                           ArrayRef<OpFoldResult> operands);
/// Variant of `makeComposedFoldedAffineApply` that applies to an expression.
OpFoldResult makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
                                           AffineExpr expr,
                                           ArrayRef<OpFoldResult> operands);
/// Variant of `makeComposedFoldedAffineApply` suitable for multi-result maps.
/// Note that this may create as many affine.apply operations as the map has
/// results given that affine.apply must be single-result.
SmallVector<OpFoldResult> makeComposedFoldedMultiResultAffineApply(
    OpBuilder &b, Location loc, AffineMap map, ArrayRef<OpFoldResult> operands);

/// Returns an AffineMinOp obtained by composing `map` and `operands` with
/// AffineApplyOps supplying those operands.
AffineMinOp makeComposedAffineMin(OpBuilder &b, Location loc, AffineMap map,
                                  ArrayRef<OpFoldResult> operands);

/// Constructs an AffineMinOp that computes a minimum across the results of
/// applying `map` to `operands`, then immediately attempts to fold it. If
/// folding results in a constant value, no ops are actually created.
OpFoldResult makeComposedFoldedAffineMin(OpBuilder &b, Location loc,
                                         AffineMap map,
                                         ArrayRef<OpFoldResult> operands);

/// Constructs an AffineMaxOp that computes a maximum across the results of
/// applying `map` to `operands`, then immediately attempts to fold it. If
/// folding results in a constant value, no ops are actually created.
OpFoldResult makeComposedFoldedAffineMax(OpBuilder &b, Location loc,
                                         AffineMap map,
                                         ArrayRef<OpFoldResult> operands);

/// Given an affine map `map` and its input `operands`, this method composes
/// into `map`, maps of AffineApplyOps whose results are the values in
/// `operands`, iteratively until no more of `operands` are the result of an
/// AffineApplyOp. When this function returns, `map` becomes the composed affine
/// map, and each Value in `operands` is guaranteed to be either a loop IV or a
/// terminal symbol, i.e., a symbol defined at the top level or a block/function
/// argument.
void fullyComposeAffineMapAndOperands(AffineMap *map,
                                      SmallVectorImpl<Value> *operands);
437
438} // namespace affine
439} // namespace mlir
440
441#include "mlir/Dialect/Affine/IR/AffineOpsDialect.h.inc"
442
443#define GET_OP_CLASSES
444#include "mlir/Dialect/Affine/IR/AffineOps.h.inc"
445
446namespace mlir {
447namespace affine {
448
/// Returns true if the provided value is the induction variable of an
/// AffineForOp.
bool isAffineForInductionVar(Value val);

/// Returns true if `val` is the induction variable of an AffineParallelOp.
bool isAffineParallelInductionVar(Value val);

/// Returns true if the provided value is the induction variable of an
/// AffineForOp or AffineParallelOp.
bool isAffineInductionVar(Value val);

/// Returns the loop parent of an induction variable. If the provided value is
/// not an induction variable, then return nullptr.
AffineForOp getForInductionVarOwner(Value val);

/// Returns the AffineParallelOp that has `val` among its induction variables.
/// NOTE(review): presumably returns a null op when `val` is not such an
/// induction variable, mirroring getForInductionVarOwner — confirm in the
/// implementation.
AffineParallelOp getAffineParallelInductionVarOwner(Value val);

/// Extracts the induction variables from a list of AffineForOps and places them
/// in the output argument `ivs`.
void extractForInductionVars(ArrayRef<AffineForOp> forInsts,
                             SmallVectorImpl<Value> *ivs);

/// Extracts the induction variables from a list of either AffineForOp or
/// AffineParallelOp and places them in the output argument `ivs`.
void extractInductionVars(ArrayRef<Operation *> affineOps,
                          SmallVectorImpl<Value> &ivs);

/// Builds a perfect nest of affine.for loops, i.e., each loop except the
/// innermost one contains only another loop and a terminator. The loops iterate
/// from "lbs" to "ubs" with "steps". The body of the innermost loop is
/// populated by calling "bodyBuilderFn" and providing it with an OpBuilder, a
/// Location and a list of loop induction variables.
void buildAffineLoopNest(OpBuilder &builder, Location loc,
                         ArrayRef<int64_t> lbs, ArrayRef<int64_t> ubs,
                         ArrayRef<int64_t> steps,
                         function_ref<void(OpBuilder &, Location, ValueRange)>
                             bodyBuilderFn = nullptr);
/// Variant of `buildAffineLoopNest` whose lower and upper bounds are given as
/// SSA values (`ValueRange`) instead of integer constants.
void buildAffineLoopNest(OpBuilder &builder, Location loc, ValueRange lbs,
                         ValueRange ubs, ArrayRef<int64_t> steps,
                         function_ref<void(OpBuilder &, Location, ValueRange)>
                             bodyBuilderFn = nullptr);
492
493/// AffineBound represents a lower or upper bound in the for operation.
494/// This class does not own the underlying operands. Instead, it refers
495/// to the operands stored in the AffineForOp. Its life span should not exceed
496/// that of the for operation it refers to.
497class AffineBound {
498public:
499 AffineForOp getAffineForOp() { return op; }
500 AffineMap getMap() { return map; }
501
502 unsigned getNumOperands() { return operands.size(); }
503 Value getOperand(unsigned idx) {
504 return op.getOperand(operands.getBeginOperandIndex() + idx);
505 }
506
507 using operand_iterator = AffineForOp::operand_iterator;
508 using operand_range = AffineForOp::operand_range;
509
510 operand_iterator operandBegin() { return operands.begin(); }
511 operand_iterator operandEnd() { return operands.end(); }
512 operand_range getOperands() { return {operandBegin(), operandEnd()}; }
513
514private:
515 // 'affine.for' operation that contains this bound.
516 AffineForOp op;
517 // Operands of the affine map.
518 OperandRange operands;
519 // Affine map for this bound.
520 AffineMap map;
521
522 AffineBound(AffineForOp op, OperandRange operands, AffineMap map)
523 : op(op), operands(operands), map(map) {}
524
525 friend class AffineForOp;
526};
527
528} // namespace affine
529} // namespace mlir
530
531#endif
532

source code of mlir/include/mlir/Dialect/Affine/IR/AffineOps.h