1 | //===-- AffinePromotion.cpp -----------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This transformation is a prototype that promotes FIR loop operations
// to affine dialect operations.
11 | // It is not part of the production pipeline and would need more work in order |
12 | // to be used in production. |
13 | // More information can be found in this presentation: |
14 | // https://slides.com/rajanwalia/deck |
15 | // |
16 | //===----------------------------------------------------------------------===// |
17 | |
18 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
19 | #include "flang/Optimizer/Dialect/FIROps.h" |
20 | #include "flang/Optimizer/Dialect/FIRType.h" |
21 | #include "flang/Optimizer/Transforms/Passes.h" |
22 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
23 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
24 | #include "mlir/Dialect/SCF/IR/SCF.h" |
25 | #include "mlir/IR/BuiltinAttributes.h" |
26 | #include "mlir/IR/IntegerSet.h" |
27 | #include "mlir/IR/Visitors.h" |
28 | #include "mlir/Transforms/DialectConversion.h" |
29 | #include "llvm/ADT/DenseMap.h" |
30 | #include "llvm/Support/Debug.h" |
31 | #include <optional> |
32 | |
33 | namespace fir { |
34 | #define GEN_PASS_DEF_AFFINEDIALECTPROMOTION |
35 | #include "flang/Optimizer/Transforms/Passes.h.inc" |
36 | } // namespace fir |
37 | |
38 | #define DEBUG_TYPE "flang-affine-promotion" |
39 | |
40 | using namespace fir; |
41 | using namespace mlir; |
42 | |
43 | namespace { |
44 | struct AffineLoopAnalysis; |
45 | struct AffineIfAnalysis; |
46 | |
47 | /// Stores analysis objects for all loops and if operations inside a function |
48 | /// these analysis are used twice, first for marking operations for rewrite and |
49 | /// second when doing rewrite. |
50 | struct AffineFunctionAnalysis { |
51 | explicit AffineFunctionAnalysis(mlir::func::FuncOp funcOp) { |
52 | funcOp->walk([&](fir::DoLoopOp doloop) { |
53 | loopAnalysisMap.try_emplace(doloop, doloop, *this); |
54 | }); |
55 | } |
56 | |
57 | AffineLoopAnalysis getChildLoopAnalysis(fir::DoLoopOp op) const; |
58 | |
59 | AffineIfAnalysis getChildIfAnalysis(fir::IfOp op) const; |
60 | |
61 | llvm::DenseMap<mlir::Operation *, AffineLoopAnalysis> loopAnalysisMap; |
62 | llvm::DenseMap<mlir::Operation *, AffineIfAnalysis> ifAnalysisMap; |
63 | }; |
64 | } // namespace |
65 | |
66 | static bool analyzeCoordinate(mlir::Value coordinate, mlir::Operation *op) { |
67 | if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(coordinate)) { |
68 | if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp())) |
69 | return true; |
70 | LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a " |
71 | "loop induction variable (owner not loopOp)\n" ; |
72 | op->dump()); |
73 | return false; |
74 | } |
75 | LLVM_DEBUG( |
76 | llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a loop " |
77 | "induction variable (not a block argument)\n" ; |
78 | op->dump(); coordinate.getDefiningOp()->dump()); |
79 | return false; |
80 | } |
81 | |
82 | namespace { |
83 | struct AffineLoopAnalysis { |
84 | AffineLoopAnalysis() = default; |
85 | |
86 | explicit AffineLoopAnalysis(fir::DoLoopOp op, AffineFunctionAnalysis &afa) |
87 | : legality(analyzeLoop(op, afa)) {} |
88 | |
89 | bool canPromoteToAffine() { return legality; } |
90 | |
91 | private: |
92 | bool analyzeBody(fir::DoLoopOp loopOperation, |
93 | AffineFunctionAnalysis &functionAnalysis) { |
94 | for (auto loopOp : loopOperation.getOps<fir::DoLoopOp>()) { |
95 | auto analysis = functionAnalysis.loopAnalysisMap |
96 | .try_emplace(loopOp, loopOp, functionAnalysis) |
97 | .first->getSecond(); |
98 | if (!analysis.canPromoteToAffine()) |
99 | return false; |
100 | } |
101 | for (auto ifOp : loopOperation.getOps<fir::IfOp>()) |
102 | functionAnalysis.ifAnalysisMap.try_emplace(ifOp, ifOp, functionAnalysis); |
103 | return true; |
104 | } |
105 | |
106 | bool analysisResults(fir::DoLoopOp loopOperation) { |
107 | if (loopOperation.getFinalValue() && |
108 | !loopOperation.getResult(0).use_empty()) { |
109 | LLVM_DEBUG( |
110 | llvm::dbgs() |
111 | << "AffineLoopAnalysis: cannot promote loop final value\n" ;); |
112 | return false; |
113 | } |
114 | |
115 | return true; |
116 | } |
117 | |
118 | bool analyzeLoop(fir::DoLoopOp loopOperation, |
119 | AffineFunctionAnalysis &functionAnalysis) { |
120 | LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: \n" ; loopOperation.dump();); |
121 | return analyzeMemoryAccess(loopOperation) && |
122 | analysisResults(loopOperation) && |
123 | analyzeBody(loopOperation, functionAnalysis); |
124 | } |
125 | |
126 | bool analyzeReference(mlir::Value memref, mlir::Operation *op) { |
127 | if (auto acoOp = memref.getDefiningOp<ArrayCoorOp>()) { |
128 | if (mlir::isa<fir::BoxType>(acoOp.getMemref().getType())) { |
129 | // TODO: Look if and how fir.box can be promoted to affine. |
130 | LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: cannot promote loop, " |
131 | "array memory operation uses fir.box\n" ; |
132 | op->dump(); acoOp.dump();); |
133 | return false; |
134 | } |
135 | bool canPromote = true; |
136 | for (auto coordinate : acoOp.getIndices()) |
137 | canPromote = canPromote && analyzeCoordinate(coordinate, op); |
138 | return canPromote; |
139 | } |
140 | if (auto coOp = memref.getDefiningOp<CoordinateOp>()) { |
141 | LLVM_DEBUG(llvm::dbgs() |
142 | << "AffineLoopAnalysis: cannot promote loop, " |
143 | "array memory operation uses non ArrayCoorOp\n" ; |
144 | op->dump(); coOp.dump();); |
145 | |
146 | return false; |
147 | } |
148 | LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: unknown type of memory " |
149 | "reference for array load\n" ; |
150 | op->dump();); |
151 | return false; |
152 | } |
153 | |
154 | bool analyzeMemoryAccess(fir::DoLoopOp loopOperation) { |
155 | for (auto loadOp : loopOperation.getOps<fir::LoadOp>()) |
156 | if (!analyzeReference(loadOp.getMemref(), loadOp)) |
157 | return false; |
158 | for (auto storeOp : loopOperation.getOps<fir::StoreOp>()) |
159 | if (!analyzeReference(storeOp.getMemref(), storeOp)) |
160 | return false; |
161 | return true; |
162 | } |
163 | |
164 | bool legality{}; |
165 | }; |
166 | } // namespace |
167 | |
168 | AffineLoopAnalysis |
169 | AffineFunctionAnalysis::getChildLoopAnalysis(fir::DoLoopOp op) const { |
170 | auto it = loopAnalysisMap.find_as(op); |
171 | if (it == loopAnalysisMap.end()) { |
172 | LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n" ; |
173 | op.dump();); |
174 | op.emitError("error in fetching loop analysis in AffineFunctionAnalysis\n" ); |
175 | return {}; |
176 | } |
177 | return it->getSecond(); |
178 | } |
179 | |
180 | namespace { |
181 | /// Calculates arguments for creating an IntegerSet. symCount, dimCount are the |
182 | /// final number of symbols and dimensions of the affine map. Integer set if |
183 | /// possible is in Optional IntegerSet. |
184 | struct AffineIfCondition { |
185 | using MaybeAffineExpr = std::optional<mlir::AffineExpr>; |
186 | |
187 | explicit AffineIfCondition(mlir::Value fc) : firCondition(fc) { |
188 | if (auto condDef = firCondition.getDefiningOp<mlir::arith::CmpIOp>()) |
189 | fromCmpIOp(condDef); |
190 | } |
191 | |
192 | bool hasIntegerSet() const { return integerSet.has_value(); } |
193 | |
194 | mlir::IntegerSet getIntegerSet() const { |
195 | assert(hasIntegerSet() && "integer set is missing" ); |
196 | return *integerSet; |
197 | } |
198 | |
199 | mlir::ValueRange getAffineArgs() const { return affineArgs; } |
200 | |
201 | private: |
202 | MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, mlir::Value lhs, |
203 | mlir::Value rhs) { |
204 | return affineBinaryOp(kind, toAffineExpr(lhs), toAffineExpr(rhs)); |
205 | } |
206 | |
207 | MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, MaybeAffineExpr lhs, |
208 | MaybeAffineExpr rhs) { |
209 | if (lhs && rhs) |
210 | return mlir::getAffineBinaryOpExpr(kind, *lhs, *rhs); |
211 | return {}; |
212 | } |
213 | |
214 | MaybeAffineExpr toAffineExpr(MaybeAffineExpr e) { return e; } |
215 | |
216 | MaybeAffineExpr toAffineExpr(int64_t value) { |
217 | return {mlir::getAffineConstantExpr(value, firCondition.getContext())}; |
218 | } |
219 | |
220 | /// Returns an AffineExpr if it is a result of operations that can be done |
221 | /// in an affine expression, this includes -, +, *, rem, constant. |
222 | /// block arguments of a loopOp or forOp are used as dimensions |
223 | MaybeAffineExpr toAffineExpr(mlir::Value value) { |
224 | if (auto op = value.getDefiningOp<mlir::arith::SubIOp>()) |
225 | return affineBinaryOp( |
226 | mlir::AffineExprKind::Add, toAffineExpr(op.getLhs()), |
227 | affineBinaryOp(mlir::AffineExprKind::Mul, toAffineExpr(op.getRhs()), |
228 | toAffineExpr(-1))); |
229 | if (auto op = value.getDefiningOp<mlir::arith::AddIOp>()) |
230 | return affineBinaryOp(mlir::AffineExprKind::Add, op.getLhs(), |
231 | op.getRhs()); |
232 | if (auto op = value.getDefiningOp<mlir::arith::MulIOp>()) |
233 | return affineBinaryOp(mlir::AffineExprKind::Mul, op.getLhs(), |
234 | op.getRhs()); |
235 | if (auto op = value.getDefiningOp<mlir::arith::RemUIOp>()) |
236 | return affineBinaryOp(mlir::AffineExprKind::Mod, op.getLhs(), |
237 | op.getRhs()); |
238 | if (auto op = value.getDefiningOp<mlir::arith::ConstantOp>()) |
239 | if (auto intConstant = mlir::dyn_cast<IntegerAttr>(op.getValue())) |
240 | return toAffineExpr(intConstant.getInt()); |
241 | if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(value)) { |
242 | affineArgs.push_back(value); |
243 | if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp()) || |
244 | isa<mlir::affine::AffineForOp>(blockArg.getOwner()->getParentOp())) |
245 | return {mlir::getAffineDimExpr(dimCount++, value.getContext())}; |
246 | return {mlir::getAffineSymbolExpr(symCount++, value.getContext())}; |
247 | } |
248 | return {}; |
249 | } |
250 | |
251 | void fromCmpIOp(mlir::arith::CmpIOp cmpOp) { |
252 | auto lhsAffine = toAffineExpr(cmpOp.getLhs()); |
253 | auto rhsAffine = toAffineExpr(cmpOp.getRhs()); |
254 | if (!lhsAffine || !rhsAffine) |
255 | return; |
256 | auto constraintPair = |
257 | constraint(cmpOp.getPredicate(), *rhsAffine - *lhsAffine); |
258 | if (!constraintPair) |
259 | return; |
260 | integerSet = mlir::IntegerSet::get( |
261 | dimCount, symCount, {constraintPair->first}, {constraintPair->second}); |
262 | } |
263 | |
264 | std::optional<std::pair<AffineExpr, bool>> |
265 | constraint(mlir::arith::CmpIPredicate predicate, mlir::AffineExpr basic) { |
266 | switch (predicate) { |
267 | case mlir::arith::CmpIPredicate::slt: |
268 | return {std::make_pair(basic - 1, false)}; |
269 | case mlir::arith::CmpIPredicate::sle: |
270 | return {std::make_pair(basic, false)}; |
271 | case mlir::arith::CmpIPredicate::sgt: |
272 | return {std::make_pair(1 - basic, false)}; |
273 | case mlir::arith::CmpIPredicate::sge: |
274 | return {std::make_pair(0 - basic, false)}; |
275 | case mlir::arith::CmpIPredicate::eq: |
276 | return {std::make_pair(basic, true)}; |
277 | default: |
278 | return {}; |
279 | } |
280 | } |
281 | |
282 | llvm::SmallVector<mlir::Value> affineArgs; |
283 | std::optional<mlir::IntegerSet> integerSet; |
284 | mlir::Value firCondition; |
285 | unsigned symCount{0u}; |
286 | unsigned dimCount{0u}; |
287 | }; |
288 | } // namespace |
289 | |
290 | namespace { |
291 | /// Analysis for affine promotion of fir.if |
292 | struct AffineIfAnalysis { |
293 | AffineIfAnalysis() = default; |
294 | |
295 | explicit AffineIfAnalysis(fir::IfOp op, AffineFunctionAnalysis &afa) |
296 | : legality(analyzeIf(op, afa)) {} |
297 | |
298 | bool canPromoteToAffine() { return legality; } |
299 | |
300 | private: |
301 | bool analyzeIf(fir::IfOp op, AffineFunctionAnalysis &afa) { |
302 | if (op.getNumResults() == 0) |
303 | return true; |
304 | LLVM_DEBUG(llvm::dbgs() |
305 | << "AffineIfAnalysis: not promoting as op has results\n" ;); |
306 | return false; |
307 | } |
308 | |
309 | bool legality{}; |
310 | }; |
311 | } // namespace |
312 | |
313 | AffineIfAnalysis |
314 | AffineFunctionAnalysis::getChildIfAnalysis(fir::IfOp op) const { |
315 | auto it = ifAnalysisMap.find_as(op); |
316 | if (it == ifAnalysisMap.end()) { |
317 | LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n" ; |
318 | op.dump();); |
319 | op.emitError("error in fetching if analysis in AffineFunctionAnalysis\n" ); |
320 | return {}; |
321 | } |
322 | return it->getSecond(); |
323 | } |
324 | |
325 | /// AffineMap rewriting fir.array_coor operation to affine apply, |
326 | /// %dim = fir.gendim %lowerBound, %upperBound, %stride |
327 | /// %a = fir.array_coor %arr(%dim) %i |
328 | /// returning affineMap = affine_map<(i)[lb, ub, st] -> (i*st - lb)> |
329 | static mlir::AffineMap createArrayIndexAffineMap(unsigned dimensions, |
330 | MLIRContext *context) { |
331 | auto index = mlir::getAffineConstantExpr(0, context); |
332 | auto accuExtent = mlir::getAffineConstantExpr(1, context); |
333 | for (unsigned i = 0; i < dimensions; ++i) { |
334 | mlir::AffineExpr idx = mlir::getAffineDimExpr(i, context), |
335 | lowerBound = mlir::getAffineSymbolExpr(i * 3, context), |
336 | currentExtent = |
337 | mlir::getAffineSymbolExpr(i * 3 + 1, context), |
338 | stride = mlir::getAffineSymbolExpr(i * 3 + 2, context), |
339 | currentPart = (idx * stride - lowerBound) * accuExtent; |
340 | index = currentPart + index; |
341 | accuExtent = accuExtent * currentExtent; |
342 | } |
343 | return mlir::AffineMap::get(dimensions, dimensions * 3, index); |
344 | } |
345 | |
346 | static std::optional<int64_t> constantIntegerLike(const mlir::Value value) { |
347 | if (auto definition = value.getDefiningOp<mlir::arith::ConstantOp>()) |
348 | if (auto stepAttr = mlir::dyn_cast<IntegerAttr>(definition.getValue())) |
349 | return stepAttr.getInt(); |
350 | return {}; |
351 | } |
352 | |
353 | static mlir::Type coordinateArrayElement(fir::ArrayCoorOp op) { |
354 | if (auto refType = |
355 | mlir::dyn_cast_or_null<ReferenceType>(op.getMemref().getType())) { |
356 | if (auto seqType = |
357 | mlir::dyn_cast_or_null<SequenceType>(refType.getEleTy())) { |
358 | return seqType.getEleTy(); |
359 | } |
360 | } |
361 | op.emitError( |
362 | "AffineLoopConversion: array type in coordinate operation not valid\n" ); |
363 | return mlir::Type(); |
364 | } |
365 | |
366 | static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeOp shape, |
367 | SmallVectorImpl<mlir::Value> &indexArgs, |
368 | mlir::PatternRewriter &rewriter) { |
369 | auto one = rewriter.create<mlir::arith::ConstantOp>( |
370 | acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1)); |
371 | auto extents = shape.getExtents(); |
372 | for (auto i = extents.begin(); i < extents.end(); i++) { |
373 | indexArgs.push_back(one); |
374 | indexArgs.push_back(*i); |
375 | indexArgs.push_back(one); |
376 | } |
377 | } |
378 | |
379 | static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeShiftOp shape, |
380 | SmallVectorImpl<mlir::Value> &indexArgs, |
381 | mlir::PatternRewriter &rewriter) { |
382 | auto one = rewriter.create<mlir::arith::ConstantOp>( |
383 | acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1)); |
384 | auto extents = shape.getPairs(); |
385 | for (auto i = extents.begin(); i < extents.end();) { |
386 | indexArgs.push_back(*i++); |
387 | indexArgs.push_back(*i++); |
388 | indexArgs.push_back(one); |
389 | } |
390 | } |
391 | |
392 | static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::SliceOp slice, |
393 | SmallVectorImpl<mlir::Value> &indexArgs, |
394 | mlir::PatternRewriter &rewriter) { |
395 | auto extents = slice.getTriples(); |
396 | for (auto i = extents.begin(); i < extents.end();) { |
397 | indexArgs.push_back(*i++); |
398 | indexArgs.push_back(*i++); |
399 | indexArgs.push_back(*i++); |
400 | } |
401 | } |
402 | |
403 | static void populateIndexArgs(fir::ArrayCoorOp acoOp, |
404 | SmallVectorImpl<mlir::Value> &indexArgs, |
405 | mlir::PatternRewriter &rewriter) { |
406 | if (auto shape = acoOp.getShape().getDefiningOp<ShapeOp>()) |
407 | return populateIndexArgs(acoOp, shape, indexArgs, rewriter); |
408 | if (auto shapeShift = acoOp.getShape().getDefiningOp<ShapeShiftOp>()) |
409 | return populateIndexArgs(acoOp, shapeShift, indexArgs, rewriter); |
410 | if (auto slice = acoOp.getShape().getDefiningOp<SliceOp>()) |
411 | return populateIndexArgs(acoOp, slice, indexArgs, rewriter); |
412 | } |
413 | |
414 | /// Returns affine.apply and fir.convert from array_coor and gendims |
415 | static std::pair<affine::AffineApplyOp, fir::ConvertOp> |
416 | createAffineOps(mlir::Value arrayRef, mlir::PatternRewriter &rewriter) { |
417 | auto acoOp = arrayRef.getDefiningOp<ArrayCoorOp>(); |
418 | auto affineMap = |
419 | createArrayIndexAffineMap(acoOp.getIndices().size(), acoOp.getContext()); |
420 | SmallVector<mlir::Value> indexArgs; |
421 | indexArgs.append(acoOp.getIndices().begin(), acoOp.getIndices().end()); |
422 | |
423 | populateIndexArgs(acoOp, indexArgs, rewriter); |
424 | |
425 | auto affineApply = rewriter.create<affine::AffineApplyOp>( |
426 | acoOp.getLoc(), affineMap, indexArgs); |
427 | auto arrayElementType = coordinateArrayElement(acoOp); |
428 | auto newType = |
429 | mlir::MemRefType::get({mlir::ShapedType::kDynamic}, arrayElementType); |
430 | auto arrayConvert = rewriter.create<fir::ConvertOp>(acoOp.getLoc(), newType, |
431 | acoOp.getMemref()); |
432 | return std::make_pair(affineApply, arrayConvert); |
433 | } |
434 | |
435 | static void rewriteLoad(fir::LoadOp loadOp, mlir::PatternRewriter &rewriter) { |
436 | rewriter.setInsertionPoint(loadOp); |
437 | auto affineOps = createAffineOps(loadOp.getMemref(), rewriter); |
438 | rewriter.replaceOpWithNewOp<affine::AffineLoadOp>( |
439 | loadOp, affineOps.second.getResult(), affineOps.first.getResult()); |
440 | } |
441 | |
442 | static void rewriteStore(fir::StoreOp storeOp, |
443 | mlir::PatternRewriter &rewriter) { |
444 | rewriter.setInsertionPoint(storeOp); |
445 | auto affineOps = createAffineOps(storeOp.getMemref(), rewriter); |
446 | rewriter.replaceOpWithNewOp<affine::AffineStoreOp>( |
447 | storeOp, storeOp.getValue(), affineOps.second.getResult(), |
448 | affineOps.first.getResult()); |
449 | } |
450 | |
451 | static void rewriteMemoryOps(Block *block, mlir::PatternRewriter &rewriter) { |
452 | for (auto &bodyOp : block->getOperations()) { |
453 | if (isa<fir::LoadOp>(bodyOp)) |
454 | rewriteLoad(cast<fir::LoadOp>(bodyOp), rewriter); |
455 | if (isa<fir::StoreOp>(bodyOp)) |
456 | rewriteStore(cast<fir::StoreOp>(bodyOp), rewriter); |
457 | } |
458 | } |
459 | |
460 | namespace { |
461 | /// Convert `fir.do_loop` to `affine.for`, creates fir.convert for arrays to |
462 | /// memref, rewrites array_coor to affine.apply with affine_map. Rewrites fir |
463 | /// loads and stores to affine. |
464 | class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> { |
465 | public: |
466 | using OpRewritePattern::OpRewritePattern; |
467 | AffineLoopConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa) |
468 | : OpRewritePattern(context), functionAnalysis(afa) {} |
469 | |
470 | llvm::LogicalResult |
471 | matchAndRewrite(fir::DoLoopOp loop, |
472 | mlir::PatternRewriter &rewriter) const override { |
473 | LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: rewriting loop:\n" ; |
474 | loop.dump();); |
475 | LLVM_ATTRIBUTE_UNUSED auto loopAnalysis = |
476 | functionAnalysis.getChildLoopAnalysis(loop); |
477 | auto &loopOps = loop.getBody()->getOperations(); |
478 | auto resultOp = cast<fir::ResultOp>(loop.getBody()->getTerminator()); |
479 | auto results = resultOp.getOperands(); |
480 | auto loopResults = loop->getResults(); |
481 | auto loopAndIndex = createAffineFor(loop, rewriter); |
482 | auto affineFor = loopAndIndex.first; |
483 | auto inductionVar = loopAndIndex.second; |
484 | |
485 | if (loop.getFinalValue()) { |
486 | results = results.drop_front(); |
487 | loopResults = loopResults.drop_front(); |
488 | } |
489 | |
490 | rewriter.startOpModification(affineFor.getOperation()); |
491 | affineFor.getBody()->getOperations().splice( |
492 | std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(), |
493 | std::prev(loopOps.end())); |
494 | rewriter.replaceAllUsesWith(loop.getRegionIterArgs(), |
495 | affineFor.getRegionIterArgs()); |
496 | if (!results.empty()) { |
497 | rewriter.setInsertionPointToEnd(affineFor.getBody()); |
498 | rewriter.create<affine::AffineYieldOp>(resultOp->getLoc(), results); |
499 | } |
500 | rewriter.finalizeOpModification(affineFor.getOperation()); |
501 | |
502 | rewriter.startOpModification(loop.getOperation()); |
503 | loop.getInductionVar().replaceAllUsesWith(inductionVar); |
504 | rewriter.finalizeOpModification(loop.getOperation()); |
505 | |
506 | rewriteMemoryOps(affineFor.getBody(), rewriter); |
507 | |
508 | LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: loop rewriten to:\n" ; |
509 | affineFor.dump();); |
510 | rewriter.replaceAllUsesWith(loopResults, affineFor->getResults()); |
511 | rewriter.eraseOp(loop); |
512 | return success(); |
513 | } |
514 | |
515 | private: |
516 | std::pair<affine::AffineForOp, mlir::Value> |
517 | createAffineFor(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const { |
518 | if (auto constantStep = constantIntegerLike(op.getStep())) |
519 | if (*constantStep > 0) |
520 | return positiveConstantStep(op, *constantStep, rewriter); |
521 | return genericBounds(op, rewriter); |
522 | } |
523 | |
524 | // when step for the loop is positive compile time constant |
525 | std::pair<affine::AffineForOp, mlir::Value> |
526 | positiveConstantStep(fir::DoLoopOp op, int64_t step, |
527 | mlir::PatternRewriter &rewriter) const { |
528 | auto affineFor = rewriter.create<affine::AffineForOp>( |
529 | op.getLoc(), ValueRange(op.getLowerBound()), |
530 | mlir::AffineMap::get(0, 1, |
531 | mlir::getAffineSymbolExpr(0, op.getContext())), |
532 | ValueRange(op.getUpperBound()), |
533 | mlir::AffineMap::get(0, 1, |
534 | 1 + mlir::getAffineSymbolExpr(0, op.getContext())), |
535 | step, op.getIterOperands()); |
536 | return std::make_pair(affineFor, affineFor.getInductionVar()); |
537 | } |
538 | |
539 | std::pair<affine::AffineForOp, mlir::Value> |
540 | genericBounds(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const { |
541 | auto lowerBound = mlir::getAffineSymbolExpr(0, op.getContext()); |
542 | auto upperBound = mlir::getAffineSymbolExpr(1, op.getContext()); |
543 | auto step = mlir::getAffineSymbolExpr(2, op.getContext()); |
544 | mlir::AffineMap upperBoundMap = mlir::AffineMap::get( |
545 | 0, 3, (upperBound - lowerBound + step).floorDiv(step)); |
546 | auto genericUpperBound = rewriter.create<affine::AffineApplyOp>( |
547 | op.getLoc(), upperBoundMap, |
548 | ValueRange({op.getLowerBound(), op.getUpperBound(), op.getStep()})); |
549 | auto actualIndexMap = mlir::AffineMap::get( |
550 | 1, 2, |
551 | (lowerBound + mlir::getAffineDimExpr(0, op.getContext())) * |
552 | mlir::getAffineSymbolExpr(1, op.getContext())); |
553 | |
554 | auto affineFor = rewriter.create<affine::AffineForOp>( |
555 | op.getLoc(), ValueRange(), |
556 | AffineMap::getConstantMap(0, op.getContext()), |
557 | genericUpperBound.getResult(), |
558 | mlir::AffineMap::get(0, 1, |
559 | 1 + mlir::getAffineSymbolExpr(0, op.getContext())), |
560 | 1, op.getIterOperands()); |
561 | rewriter.setInsertionPointToStart(affineFor.getBody()); |
562 | auto actualIndex = rewriter.create<affine::AffineApplyOp>( |
563 | op.getLoc(), actualIndexMap, |
564 | ValueRange( |
565 | {affineFor.getInductionVar(), op.getLowerBound(), op.getStep()})); |
566 | return std::make_pair(affineFor, actualIndex.getResult()); |
567 | } |
568 | |
569 | AffineFunctionAnalysis &functionAnalysis; |
570 | }; |
571 | |
572 | /// Convert `fir.if` to `affine.if`. |
573 | class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> { |
574 | public: |
575 | using OpRewritePattern::OpRewritePattern; |
576 | AffineIfConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa) |
577 | : OpRewritePattern(context) {} |
578 | llvm::LogicalResult |
579 | matchAndRewrite(fir::IfOp op, |
580 | mlir::PatternRewriter &rewriter) const override { |
581 | LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: rewriting if:\n" ; |
582 | op.dump();); |
583 | auto &ifOps = op.getThenRegion().front().getOperations(); |
584 | auto affineCondition = AffineIfCondition(op.getCondition()); |
585 | if (!affineCondition.hasIntegerSet()) { |
586 | LLVM_DEBUG( |
587 | llvm::dbgs() |
588 | << "AffineIfConversion: couldn't calculate affine condition\n" ;); |
589 | return failure(); |
590 | } |
591 | auto affineIf = rewriter.create<affine::AffineIfOp>( |
592 | op.getLoc(), affineCondition.getIntegerSet(), |
593 | affineCondition.getAffineArgs(), !op.getElseRegion().empty()); |
594 | rewriter.startOpModification(affineIf); |
595 | affineIf.getThenBlock()->getOperations().splice( |
596 | std::prev(affineIf.getThenBlock()->end()), ifOps, ifOps.begin(), |
597 | std::prev(ifOps.end())); |
598 | if (!op.getElseRegion().empty()) { |
599 | auto &otherOps = op.getElseRegion().front().getOperations(); |
600 | affineIf.getElseBlock()->getOperations().splice( |
601 | std::prev(affineIf.getElseBlock()->end()), otherOps, otherOps.begin(), |
602 | std::prev(otherOps.end())); |
603 | } |
604 | rewriter.finalizeOpModification(affineIf); |
605 | rewriteMemoryOps(affineIf.getBody(), rewriter); |
606 | |
607 | LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: if converted to:\n" ; |
608 | affineIf.dump();); |
609 | rewriter.replaceOp(op, affineIf.getOperation()->getResults()); |
610 | return success(); |
611 | } |
612 | }; |
613 | |
614 | /// Promote fir.do_loop and fir.if to affine.for and affine.if, in the cases |
615 | /// where such a promotion is possible. |
616 | class AffineDialectPromotion |
617 | : public fir::impl::AffineDialectPromotionBase<AffineDialectPromotion> { |
618 | public: |
619 | void runOnOperation() override { |
620 | |
621 | auto *context = &getContext(); |
622 | auto function = getOperation(); |
623 | markAllAnalysesPreserved(); |
624 | auto functionAnalysis = AffineFunctionAnalysis(function); |
625 | mlir::RewritePatternSet patterns(context); |
626 | patterns.insert<AffineIfConversion>(context, functionAnalysis); |
627 | patterns.insert<AffineLoopConversion>(context, functionAnalysis); |
628 | mlir::ConversionTarget target = *context; |
629 | target.addLegalDialect<mlir::affine::AffineDialect, FIROpsDialect, |
630 | mlir::scf::SCFDialect, mlir::arith::ArithDialect, |
631 | mlir::func::FuncDialect>(); |
632 | target.addDynamicallyLegalOp<IfOp>([&functionAnalysis](fir::IfOp op) { |
633 | return !(functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine()); |
634 | }); |
635 | target.addDynamicallyLegalOp<DoLoopOp>([&functionAnalysis]( |
636 | fir::DoLoopOp op) { |
637 | return !(functionAnalysis.getChildLoopAnalysis(op).canPromoteToAffine()); |
638 | }); |
639 | |
640 | LLVM_DEBUG(llvm::dbgs() |
641 | << "AffineDialectPromotion: running promotion on: \n" ; |
642 | function.print(llvm::dbgs());); |
643 | // apply the patterns |
644 | if (mlir::failed(mlir::applyPartialConversion(function, target, |
645 | std::move(patterns)))) { |
646 | mlir::emitError(mlir::UnknownLoc::get(context), |
647 | "error in converting to affine dialect\n" ); |
648 | signalPassFailure(); |
649 | } |
650 | } |
651 | }; |
652 | } // namespace |
653 | |
654 | /// Convert FIR loop constructs to the Affine dialect |
655 | std::unique_ptr<mlir::Pass> fir::createPromoteToAffinePass() { |
656 | return std::make_unique<AffineDialectPromotion>(); |
657 | } |
658 | |