1 | //===-- AffinePromotion.cpp -----------------------------------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
// This transformation is a prototype that promotes FIR loop operations
// to affine dialect operations.
11 | // It is not part of the production pipeline and would need more work in order |
12 | // to be used in production. |
13 | // More information can be found in this presentation: |
14 | // https://slides.com/rajanwalia/deck |
15 | // |
16 | //===----------------------------------------------------------------------===// |
17 | |
18 | #include "flang/Optimizer/Dialect/FIRDialect.h" |
19 | #include "flang/Optimizer/Dialect/FIROps.h" |
20 | #include "flang/Optimizer/Dialect/FIRType.h" |
21 | #include "flang/Optimizer/Transforms/Passes.h" |
22 | #include "mlir/Dialect/Affine/IR/AffineOps.h" |
23 | #include "mlir/Dialect/Func/IR/FuncOps.h" |
24 | #include "mlir/Dialect/SCF/IR/SCF.h" |
25 | #include "mlir/IR/BuiltinAttributes.h" |
26 | #include "mlir/IR/IntegerSet.h" |
27 | #include "mlir/IR/Visitors.h" |
28 | #include "mlir/Transforms/DialectConversion.h" |
29 | #include "llvm/ADT/DenseMap.h" |
30 | #include "llvm/Support/Debug.h" |
31 | #include <optional> |
32 | |
33 | namespace fir { |
34 | #define GEN_PASS_DEF_AFFINEDIALECTPROMOTION |
35 | #include "flang/Optimizer/Transforms/Passes.h.inc" |
36 | } // namespace fir |
37 | |
38 | #define DEBUG_TYPE "flang-affine-promotion" |
39 | |
40 | using namespace fir; |
41 | using namespace mlir; |
42 | |
43 | namespace { |
44 | struct AffineLoopAnalysis; |
45 | struct AffineIfAnalysis; |
46 | |
47 | /// Stores analysis objects for all loops and if operations inside a function |
48 | /// these analysis are used twice, first for marking operations for rewrite and |
49 | /// second when doing rewrite. |
50 | struct AffineFunctionAnalysis { |
51 | explicit AffineFunctionAnalysis(mlir::func::FuncOp funcOp) { |
52 | for (fir::DoLoopOp op : funcOp.getOps<fir::DoLoopOp>()) |
53 | loopAnalysisMap.try_emplace(op, op, *this); |
54 | } |
55 | |
56 | AffineLoopAnalysis getChildLoopAnalysis(fir::DoLoopOp op) const; |
57 | |
58 | AffineIfAnalysis getChildIfAnalysis(fir::IfOp op) const; |
59 | |
60 | llvm::DenseMap<mlir::Operation *, AffineLoopAnalysis> loopAnalysisMap; |
61 | llvm::DenseMap<mlir::Operation *, AffineIfAnalysis> ifAnalysisMap; |
62 | }; |
63 | } // namespace |
64 | |
65 | static bool analyzeCoordinate(mlir::Value coordinate, mlir::Operation *op) { |
66 | if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(coordinate)) { |
67 | if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp())) |
68 | return true; |
69 | LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a " |
70 | "loop induction variable (owner not loopOp)\n" ; |
71 | op->dump()); |
72 | return false; |
73 | } |
74 | LLVM_DEBUG( |
75 | llvm::dbgs() << "AffineLoopAnalysis: array coordinate is not a loop " |
76 | "induction variable (not a block argument)\n" ; |
77 | op->dump(); coordinate.getDefiningOp()->dump()); |
78 | return false; |
79 | } |
80 | |
81 | namespace { |
82 | struct AffineLoopAnalysis { |
83 | AffineLoopAnalysis() = default; |
84 | |
85 | explicit AffineLoopAnalysis(fir::DoLoopOp op, AffineFunctionAnalysis &afa) |
86 | : legality(analyzeLoop(op, afa)) {} |
87 | |
88 | bool canPromoteToAffine() { return legality; } |
89 | |
90 | private: |
91 | bool analyzeBody(fir::DoLoopOp loopOperation, |
92 | AffineFunctionAnalysis &functionAnalysis) { |
93 | for (auto loopOp : loopOperation.getOps<fir::DoLoopOp>()) { |
94 | auto analysis = functionAnalysis.loopAnalysisMap |
95 | .try_emplace(loopOp, loopOp, functionAnalysis) |
96 | .first->getSecond(); |
97 | if (!analysis.canPromoteToAffine()) |
98 | return false; |
99 | } |
100 | for (auto ifOp : loopOperation.getOps<fir::IfOp>()) |
101 | functionAnalysis.ifAnalysisMap.try_emplace(ifOp, ifOp, functionAnalysis); |
102 | return true; |
103 | } |
104 | |
105 | bool analyzeLoop(fir::DoLoopOp loopOperation, |
106 | AffineFunctionAnalysis &functionAnalysis) { |
107 | LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: \n" ; loopOperation.dump();); |
108 | return analyzeMemoryAccess(loopOperation) && |
109 | analyzeBody(loopOperation, functionAnalysis); |
110 | } |
111 | |
112 | bool analyzeReference(mlir::Value memref, mlir::Operation *op) { |
113 | if (auto acoOp = memref.getDefiningOp<ArrayCoorOp>()) { |
114 | if (acoOp.getMemref().getType().isa<fir::BoxType>()) { |
115 | // TODO: Look if and how fir.box can be promoted to affine. |
116 | LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: cannot promote loop, " |
117 | "array memory operation uses fir.box\n" ; |
118 | op->dump(); acoOp.dump();); |
119 | return false; |
120 | } |
121 | bool canPromote = true; |
122 | for (auto coordinate : acoOp.getIndices()) |
123 | canPromote = canPromote && analyzeCoordinate(coordinate, op); |
124 | return canPromote; |
125 | } |
126 | if (auto coOp = memref.getDefiningOp<CoordinateOp>()) { |
127 | LLVM_DEBUG(llvm::dbgs() |
128 | << "AffineLoopAnalysis: cannot promote loop, " |
129 | "array memory operation uses non ArrayCoorOp\n" ; |
130 | op->dump(); coOp.dump();); |
131 | |
132 | return false; |
133 | } |
134 | LLVM_DEBUG(llvm::dbgs() << "AffineLoopAnalysis: unknown type of memory " |
135 | "reference for array load\n" ; |
136 | op->dump();); |
137 | return false; |
138 | } |
139 | |
140 | bool analyzeMemoryAccess(fir::DoLoopOp loopOperation) { |
141 | for (auto loadOp : loopOperation.getOps<fir::LoadOp>()) |
142 | if (!analyzeReference(loadOp.getMemref(), loadOp)) |
143 | return false; |
144 | for (auto storeOp : loopOperation.getOps<fir::StoreOp>()) |
145 | if (!analyzeReference(storeOp.getMemref(), storeOp)) |
146 | return false; |
147 | return true; |
148 | } |
149 | |
150 | bool legality{}; |
151 | }; |
152 | } // namespace |
153 | |
154 | AffineLoopAnalysis |
155 | AffineFunctionAnalysis::getChildLoopAnalysis(fir::DoLoopOp op) const { |
156 | auto it = loopAnalysisMap.find_as(op); |
157 | if (it == loopAnalysisMap.end()) { |
158 | LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n" ; |
159 | op.dump();); |
160 | op.emitError("error in fetching loop analysis in AffineFunctionAnalysis\n" ); |
161 | return {}; |
162 | } |
163 | return it->getSecond(); |
164 | } |
165 | |
166 | namespace { |
167 | /// Calculates arguments for creating an IntegerSet. symCount, dimCount are the |
168 | /// final number of symbols and dimensions of the affine map. Integer set if |
169 | /// possible is in Optional IntegerSet. |
170 | struct AffineIfCondition { |
171 | using MaybeAffineExpr = std::optional<mlir::AffineExpr>; |
172 | |
173 | explicit AffineIfCondition(mlir::Value fc) : firCondition(fc) { |
174 | if (auto condDef = firCondition.getDefiningOp<mlir::arith::CmpIOp>()) |
175 | fromCmpIOp(condDef); |
176 | } |
177 | |
178 | bool hasIntegerSet() const { return integerSet.has_value(); } |
179 | |
180 | mlir::IntegerSet getIntegerSet() const { |
181 | assert(hasIntegerSet() && "integer set is missing" ); |
182 | return *integerSet; |
183 | } |
184 | |
185 | mlir::ValueRange getAffineArgs() const { return affineArgs; } |
186 | |
187 | private: |
188 | MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, mlir::Value lhs, |
189 | mlir::Value rhs) { |
190 | return affineBinaryOp(kind, toAffineExpr(lhs), toAffineExpr(rhs)); |
191 | } |
192 | |
193 | MaybeAffineExpr affineBinaryOp(mlir::AffineExprKind kind, MaybeAffineExpr lhs, |
194 | MaybeAffineExpr rhs) { |
195 | if (lhs && rhs) |
196 | return mlir::getAffineBinaryOpExpr(kind, *lhs, *rhs); |
197 | return {}; |
198 | } |
199 | |
200 | MaybeAffineExpr toAffineExpr(MaybeAffineExpr e) { return e; } |
201 | |
202 | MaybeAffineExpr toAffineExpr(int64_t value) { |
203 | return {mlir::getAffineConstantExpr(value, firCondition.getContext())}; |
204 | } |
205 | |
206 | /// Returns an AffineExpr if it is a result of operations that can be done |
207 | /// in an affine expression, this includes -, +, *, rem, constant. |
208 | /// block arguments of a loopOp or forOp are used as dimensions |
209 | MaybeAffineExpr toAffineExpr(mlir::Value value) { |
210 | if (auto op = value.getDefiningOp<mlir::arith::SubIOp>()) |
211 | return affineBinaryOp( |
212 | mlir::AffineExprKind::Add, toAffineExpr(op.getLhs()), |
213 | affineBinaryOp(mlir::AffineExprKind::Mul, toAffineExpr(op.getRhs()), |
214 | toAffineExpr(-1))); |
215 | if (auto op = value.getDefiningOp<mlir::arith::AddIOp>()) |
216 | return affineBinaryOp(mlir::AffineExprKind::Add, op.getLhs(), |
217 | op.getRhs()); |
218 | if (auto op = value.getDefiningOp<mlir::arith::MulIOp>()) |
219 | return affineBinaryOp(mlir::AffineExprKind::Mul, op.getLhs(), |
220 | op.getRhs()); |
221 | if (auto op = value.getDefiningOp<mlir::arith::RemUIOp>()) |
222 | return affineBinaryOp(mlir::AffineExprKind::Mod, op.getLhs(), |
223 | op.getRhs()); |
224 | if (auto op = value.getDefiningOp<mlir::arith::ConstantOp>()) |
225 | if (auto intConstant = op.getValue().dyn_cast<IntegerAttr>()) |
226 | return toAffineExpr(intConstant.getInt()); |
227 | if (auto blockArg = mlir::dyn_cast<mlir::BlockArgument>(value)) { |
228 | affineArgs.push_back(value); |
229 | if (isa<fir::DoLoopOp>(blockArg.getOwner()->getParentOp()) || |
230 | isa<mlir::affine::AffineForOp>(blockArg.getOwner()->getParentOp())) |
231 | return {mlir::getAffineDimExpr(dimCount++, value.getContext())}; |
232 | return {mlir::getAffineSymbolExpr(symCount++, value.getContext())}; |
233 | } |
234 | return {}; |
235 | } |
236 | |
237 | void fromCmpIOp(mlir::arith::CmpIOp cmpOp) { |
238 | auto lhsAffine = toAffineExpr(cmpOp.getLhs()); |
239 | auto rhsAffine = toAffineExpr(cmpOp.getRhs()); |
240 | if (!lhsAffine || !rhsAffine) |
241 | return; |
242 | auto constraintPair = |
243 | constraint(cmpOp.getPredicate(), *rhsAffine - *lhsAffine); |
244 | if (!constraintPair) |
245 | return; |
246 | integerSet = mlir::IntegerSet::get( |
247 | dimCount, symCount, {constraintPair->first}, {constraintPair->second}); |
248 | } |
249 | |
250 | std::optional<std::pair<AffineExpr, bool>> |
251 | constraint(mlir::arith::CmpIPredicate predicate, mlir::AffineExpr basic) { |
252 | switch (predicate) { |
253 | case mlir::arith::CmpIPredicate::slt: |
254 | return {std::make_pair(basic - 1, false)}; |
255 | case mlir::arith::CmpIPredicate::sle: |
256 | return {std::make_pair(basic, false)}; |
257 | case mlir::arith::CmpIPredicate::sgt: |
258 | return {std::make_pair(1 - basic, false)}; |
259 | case mlir::arith::CmpIPredicate::sge: |
260 | return {std::make_pair(0 - basic, false)}; |
261 | case mlir::arith::CmpIPredicate::eq: |
262 | return {std::make_pair(basic, true)}; |
263 | default: |
264 | return {}; |
265 | } |
266 | } |
267 | |
268 | llvm::SmallVector<mlir::Value> affineArgs; |
269 | std::optional<mlir::IntegerSet> integerSet; |
270 | mlir::Value firCondition; |
271 | unsigned symCount{0u}; |
272 | unsigned dimCount{0u}; |
273 | }; |
274 | } // namespace |
275 | |
276 | namespace { |
277 | /// Analysis for affine promotion of fir.if |
278 | struct AffineIfAnalysis { |
279 | AffineIfAnalysis() = default; |
280 | |
281 | explicit AffineIfAnalysis(fir::IfOp op, AffineFunctionAnalysis &afa) |
282 | : legality(analyzeIf(op, afa)) {} |
283 | |
284 | bool canPromoteToAffine() { return legality; } |
285 | |
286 | private: |
287 | bool analyzeIf(fir::IfOp op, AffineFunctionAnalysis &afa) { |
288 | if (op.getNumResults() == 0) |
289 | return true; |
290 | LLVM_DEBUG(llvm::dbgs() |
291 | << "AffineIfAnalysis: not promoting as op has results\n" ;); |
292 | return false; |
293 | } |
294 | |
295 | bool legality{}; |
296 | }; |
297 | } // namespace |
298 | |
299 | AffineIfAnalysis |
300 | AffineFunctionAnalysis::getChildIfAnalysis(fir::IfOp op) const { |
301 | auto it = ifAnalysisMap.find_as(op); |
302 | if (it == ifAnalysisMap.end()) { |
303 | LLVM_DEBUG(llvm::dbgs() << "AffineFunctionAnalysis: not computed for:\n" ; |
304 | op.dump();); |
305 | op.emitError("error in fetching if analysis in AffineFunctionAnalysis\n" ); |
306 | return {}; |
307 | } |
308 | return it->getSecond(); |
309 | } |
310 | |
311 | /// AffineMap rewriting fir.array_coor operation to affine apply, |
312 | /// %dim = fir.gendim %lowerBound, %upperBound, %stride |
313 | /// %a = fir.array_coor %arr(%dim) %i |
314 | /// returning affineMap = affine_map<(i)[lb, ub, st] -> (i*st - lb)> |
315 | static mlir::AffineMap createArrayIndexAffineMap(unsigned dimensions, |
316 | MLIRContext *context) { |
317 | auto index = mlir::getAffineConstantExpr(0, context); |
318 | auto accuExtent = mlir::getAffineConstantExpr(1, context); |
319 | for (unsigned i = 0; i < dimensions; ++i) { |
320 | mlir::AffineExpr idx = mlir::getAffineDimExpr(i, context), |
321 | lowerBound = mlir::getAffineSymbolExpr(i * 3, context), |
322 | currentExtent = |
323 | mlir::getAffineSymbolExpr(i * 3 + 1, context), |
324 | stride = mlir::getAffineSymbolExpr(i * 3 + 2, context), |
325 | currentPart = (idx * stride - lowerBound) * accuExtent; |
326 | index = currentPart + index; |
327 | accuExtent = accuExtent * currentExtent; |
328 | } |
329 | return mlir::AffineMap::get(dimensions, dimensions * 3, index); |
330 | } |
331 | |
332 | static std::optional<int64_t> constantIntegerLike(const mlir::Value value) { |
333 | if (auto definition = value.getDefiningOp<mlir::arith::ConstantOp>()) |
334 | if (auto stepAttr = definition.getValue().dyn_cast<IntegerAttr>()) |
335 | return stepAttr.getInt(); |
336 | return {}; |
337 | } |
338 | |
339 | static mlir::Type coordinateArrayElement(fir::ArrayCoorOp op) { |
340 | if (auto refType = |
341 | op.getMemref().getType().dyn_cast_or_null<ReferenceType>()) { |
342 | if (auto seqType = refType.getEleTy().dyn_cast_or_null<SequenceType>()) { |
343 | return seqType.getEleTy(); |
344 | } |
345 | } |
346 | op.emitError( |
347 | "AffineLoopConversion: array type in coordinate operation not valid\n" ); |
348 | return mlir::Type(); |
349 | } |
350 | |
351 | static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeOp shape, |
352 | SmallVectorImpl<mlir::Value> &indexArgs, |
353 | mlir::PatternRewriter &rewriter) { |
354 | auto one = rewriter.create<mlir::arith::ConstantOp>( |
355 | acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1)); |
356 | auto extents = shape.getExtents(); |
357 | for (auto i = extents.begin(); i < extents.end(); i++) { |
358 | indexArgs.push_back(one); |
359 | indexArgs.push_back(*i); |
360 | indexArgs.push_back(one); |
361 | } |
362 | } |
363 | |
364 | static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::ShapeShiftOp shape, |
365 | SmallVectorImpl<mlir::Value> &indexArgs, |
366 | mlir::PatternRewriter &rewriter) { |
367 | auto one = rewriter.create<mlir::arith::ConstantOp>( |
368 | acoOp.getLoc(), rewriter.getIndexType(), rewriter.getIndexAttr(1)); |
369 | auto extents = shape.getPairs(); |
370 | for (auto i = extents.begin(); i < extents.end();) { |
371 | indexArgs.push_back(*i++); |
372 | indexArgs.push_back(*i++); |
373 | indexArgs.push_back(one); |
374 | } |
375 | } |
376 | |
377 | static void populateIndexArgs(fir::ArrayCoorOp acoOp, fir::SliceOp slice, |
378 | SmallVectorImpl<mlir::Value> &indexArgs, |
379 | mlir::PatternRewriter &rewriter) { |
380 | auto extents = slice.getTriples(); |
381 | for (auto i = extents.begin(); i < extents.end();) { |
382 | indexArgs.push_back(*i++); |
383 | indexArgs.push_back(*i++); |
384 | indexArgs.push_back(*i++); |
385 | } |
386 | } |
387 | |
388 | static void populateIndexArgs(fir::ArrayCoorOp acoOp, |
389 | SmallVectorImpl<mlir::Value> &indexArgs, |
390 | mlir::PatternRewriter &rewriter) { |
391 | if (auto shape = acoOp.getShape().getDefiningOp<ShapeOp>()) |
392 | return populateIndexArgs(acoOp, shape, indexArgs, rewriter); |
393 | if (auto shapeShift = acoOp.getShape().getDefiningOp<ShapeShiftOp>()) |
394 | return populateIndexArgs(acoOp, shapeShift, indexArgs, rewriter); |
395 | if (auto slice = acoOp.getShape().getDefiningOp<SliceOp>()) |
396 | return populateIndexArgs(acoOp, slice, indexArgs, rewriter); |
397 | } |
398 | |
399 | /// Returns affine.apply and fir.convert from array_coor and gendims |
400 | static std::pair<affine::AffineApplyOp, fir::ConvertOp> |
401 | createAffineOps(mlir::Value arrayRef, mlir::PatternRewriter &rewriter) { |
402 | auto acoOp = arrayRef.getDefiningOp<ArrayCoorOp>(); |
403 | auto affineMap = |
404 | createArrayIndexAffineMap(acoOp.getIndices().size(), acoOp.getContext()); |
405 | SmallVector<mlir::Value> indexArgs; |
406 | indexArgs.append(acoOp.getIndices().begin(), acoOp.getIndices().end()); |
407 | |
408 | populateIndexArgs(acoOp, indexArgs, rewriter); |
409 | |
410 | auto affineApply = rewriter.create<affine::AffineApplyOp>( |
411 | acoOp.getLoc(), affineMap, indexArgs); |
412 | auto arrayElementType = coordinateArrayElement(acoOp); |
413 | auto newType = |
414 | mlir::MemRefType::get({mlir::ShapedType::kDynamic}, arrayElementType); |
415 | auto arrayConvert = rewriter.create<fir::ConvertOp>(acoOp.getLoc(), newType, |
416 | acoOp.getMemref()); |
417 | return std::make_pair(affineApply, arrayConvert); |
418 | } |
419 | |
420 | static void rewriteLoad(fir::LoadOp loadOp, mlir::PatternRewriter &rewriter) { |
421 | rewriter.setInsertionPoint(loadOp); |
422 | auto affineOps = createAffineOps(loadOp.getMemref(), rewriter); |
423 | rewriter.replaceOpWithNewOp<affine::AffineLoadOp>( |
424 | loadOp, affineOps.second.getResult(), affineOps.first.getResult()); |
425 | } |
426 | |
427 | static void rewriteStore(fir::StoreOp storeOp, |
428 | mlir::PatternRewriter &rewriter) { |
429 | rewriter.setInsertionPoint(storeOp); |
430 | auto affineOps = createAffineOps(storeOp.getMemref(), rewriter); |
431 | rewriter.replaceOpWithNewOp<affine::AffineStoreOp>( |
432 | storeOp, storeOp.getValue(), affineOps.second.getResult(), |
433 | affineOps.first.getResult()); |
434 | } |
435 | |
436 | static void rewriteMemoryOps(Block *block, mlir::PatternRewriter &rewriter) { |
437 | for (auto &bodyOp : block->getOperations()) { |
438 | if (isa<fir::LoadOp>(bodyOp)) |
439 | rewriteLoad(cast<fir::LoadOp>(bodyOp), rewriter); |
440 | if (isa<fir::StoreOp>(bodyOp)) |
441 | rewriteStore(cast<fir::StoreOp>(bodyOp), rewriter); |
442 | } |
443 | } |
444 | |
445 | namespace { |
446 | /// Convert `fir.do_loop` to `affine.for`, creates fir.convert for arrays to |
447 | /// memref, rewrites array_coor to affine.apply with affine_map. Rewrites fir |
448 | /// loads and stores to affine. |
449 | class AffineLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> { |
450 | public: |
451 | using OpRewritePattern::OpRewritePattern; |
452 | AffineLoopConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa) |
453 | : OpRewritePattern(context), functionAnalysis(afa) {} |
454 | |
455 | mlir::LogicalResult |
456 | matchAndRewrite(fir::DoLoopOp loop, |
457 | mlir::PatternRewriter &rewriter) const override { |
458 | LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: rewriting loop:\n" ; |
459 | loop.dump();); |
460 | LLVM_ATTRIBUTE_UNUSED auto loopAnalysis = |
461 | functionAnalysis.getChildLoopAnalysis(loop); |
462 | auto &loopOps = loop.getBody()->getOperations(); |
463 | auto loopAndIndex = createAffineFor(loop, rewriter); |
464 | auto affineFor = loopAndIndex.first; |
465 | auto inductionVar = loopAndIndex.second; |
466 | |
467 | rewriter.startOpModification(affineFor.getOperation()); |
468 | affineFor.getBody()->getOperations().splice( |
469 | std::prev(affineFor.getBody()->end()), loopOps, loopOps.begin(), |
470 | std::prev(loopOps.end())); |
471 | rewriter.finalizeOpModification(affineFor.getOperation()); |
472 | |
473 | rewriter.startOpModification(loop.getOperation()); |
474 | loop.getInductionVar().replaceAllUsesWith(inductionVar); |
475 | rewriter.finalizeOpModification(loop.getOperation()); |
476 | |
477 | rewriteMemoryOps(affineFor.getBody(), rewriter); |
478 | |
479 | LLVM_DEBUG(llvm::dbgs() << "AffineLoopConversion: loop rewriten to:\n" ; |
480 | affineFor.dump();); |
481 | rewriter.replaceOp(loop, affineFor.getOperation()->getResults()); |
482 | return success(); |
483 | } |
484 | |
485 | private: |
486 | std::pair<affine::AffineForOp, mlir::Value> |
487 | createAffineFor(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const { |
488 | if (auto constantStep = constantIntegerLike(op.getStep())) |
489 | if (*constantStep > 0) |
490 | return positiveConstantStep(op, *constantStep, rewriter); |
491 | return genericBounds(op, rewriter); |
492 | } |
493 | |
494 | // when step for the loop is positive compile time constant |
495 | std::pair<affine::AffineForOp, mlir::Value> |
496 | positiveConstantStep(fir::DoLoopOp op, int64_t step, |
497 | mlir::PatternRewriter &rewriter) const { |
498 | auto affineFor = rewriter.create<affine::AffineForOp>( |
499 | op.getLoc(), ValueRange(op.getLowerBound()), |
500 | mlir::AffineMap::get(0, 1, |
501 | mlir::getAffineSymbolExpr(0, op.getContext())), |
502 | ValueRange(op.getUpperBound()), |
503 | mlir::AffineMap::get(0, 1, |
504 | 1 + mlir::getAffineSymbolExpr(0, op.getContext())), |
505 | step); |
506 | return std::make_pair(affineFor, affineFor.getInductionVar()); |
507 | } |
508 | |
509 | std::pair<affine::AffineForOp, mlir::Value> |
510 | genericBounds(fir::DoLoopOp op, mlir::PatternRewriter &rewriter) const { |
511 | auto lowerBound = mlir::getAffineSymbolExpr(0, op.getContext()); |
512 | auto upperBound = mlir::getAffineSymbolExpr(1, op.getContext()); |
513 | auto step = mlir::getAffineSymbolExpr(2, op.getContext()); |
514 | mlir::AffineMap upperBoundMap = mlir::AffineMap::get( |
515 | 0, 3, (upperBound - lowerBound + step).floorDiv(step)); |
516 | auto genericUpperBound = rewriter.create<affine::AffineApplyOp>( |
517 | op.getLoc(), upperBoundMap, |
518 | ValueRange({op.getLowerBound(), op.getUpperBound(), op.getStep()})); |
519 | auto actualIndexMap = mlir::AffineMap::get( |
520 | 1, 2, |
521 | (lowerBound + mlir::getAffineDimExpr(0, op.getContext())) * |
522 | mlir::getAffineSymbolExpr(1, op.getContext())); |
523 | |
524 | auto affineFor = rewriter.create<affine::AffineForOp>( |
525 | op.getLoc(), ValueRange(), |
526 | AffineMap::getConstantMap(0, op.getContext()), |
527 | genericUpperBound.getResult(), |
528 | mlir::AffineMap::get(0, 1, |
529 | 1 + mlir::getAffineSymbolExpr(0, op.getContext())), |
530 | 1); |
531 | rewriter.setInsertionPointToStart(affineFor.getBody()); |
532 | auto actualIndex = rewriter.create<affine::AffineApplyOp>( |
533 | op.getLoc(), actualIndexMap, |
534 | ValueRange( |
535 | {affineFor.getInductionVar(), op.getLowerBound(), op.getStep()})); |
536 | return std::make_pair(affineFor, actualIndex.getResult()); |
537 | } |
538 | |
539 | AffineFunctionAnalysis &functionAnalysis; |
540 | }; |
541 | |
542 | /// Convert `fir.if` to `affine.if`. |
543 | class AffineIfConversion : public mlir::OpRewritePattern<fir::IfOp> { |
544 | public: |
545 | using OpRewritePattern::OpRewritePattern; |
546 | AffineIfConversion(mlir::MLIRContext *context, AffineFunctionAnalysis &afa) |
547 | : OpRewritePattern(context) {} |
548 | mlir::LogicalResult |
549 | matchAndRewrite(fir::IfOp op, |
550 | mlir::PatternRewriter &rewriter) const override { |
551 | LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: rewriting if:\n" ; |
552 | op.dump();); |
553 | auto &ifOps = op.getThenRegion().front().getOperations(); |
554 | auto affineCondition = AffineIfCondition(op.getCondition()); |
555 | if (!affineCondition.hasIntegerSet()) { |
556 | LLVM_DEBUG( |
557 | llvm::dbgs() |
558 | << "AffineIfConversion: couldn't calculate affine condition\n" ;); |
559 | return failure(); |
560 | } |
561 | auto affineIf = rewriter.create<affine::AffineIfOp>( |
562 | op.getLoc(), affineCondition.getIntegerSet(), |
563 | affineCondition.getAffineArgs(), !op.getElseRegion().empty()); |
564 | rewriter.startOpModification(affineIf); |
565 | affineIf.getThenBlock()->getOperations().splice( |
566 | std::prev(affineIf.getThenBlock()->end()), ifOps, ifOps.begin(), |
567 | std::prev(ifOps.end())); |
568 | if (!op.getElseRegion().empty()) { |
569 | auto &otherOps = op.getElseRegion().front().getOperations(); |
570 | affineIf.getElseBlock()->getOperations().splice( |
571 | std::prev(affineIf.getElseBlock()->end()), otherOps, otherOps.begin(), |
572 | std::prev(otherOps.end())); |
573 | } |
574 | rewriter.finalizeOpModification(affineIf); |
575 | rewriteMemoryOps(affineIf.getBody(), rewriter); |
576 | |
577 | LLVM_DEBUG(llvm::dbgs() << "AffineIfConversion: if converted to:\n" ; |
578 | affineIf.dump();); |
579 | rewriter.replaceOp(op, affineIf.getOperation()->getResults()); |
580 | return success(); |
581 | } |
582 | }; |
583 | |
584 | /// Promote fir.do_loop and fir.if to affine.for and affine.if, in the cases |
585 | /// where such a promotion is possible. |
586 | class AffineDialectPromotion |
587 | : public fir::impl::AffineDialectPromotionBase<AffineDialectPromotion> { |
588 | public: |
589 | void runOnOperation() override { |
590 | |
591 | auto *context = &getContext(); |
592 | auto function = getOperation(); |
593 | markAllAnalysesPreserved(); |
594 | auto functionAnalysis = AffineFunctionAnalysis(function); |
595 | mlir::RewritePatternSet patterns(context); |
596 | patterns.insert<AffineIfConversion>(context, functionAnalysis); |
597 | patterns.insert<AffineLoopConversion>(context, functionAnalysis); |
598 | mlir::ConversionTarget target = *context; |
599 | target.addLegalDialect<mlir::affine::AffineDialect, FIROpsDialect, |
600 | mlir::scf::SCFDialect, mlir::arith::ArithDialect, |
601 | mlir::func::FuncDialect>(); |
602 | target.addDynamicallyLegalOp<IfOp>([&functionAnalysis](fir::IfOp op) { |
603 | return !(functionAnalysis.getChildIfAnalysis(op).canPromoteToAffine()); |
604 | }); |
605 | target.addDynamicallyLegalOp<DoLoopOp>([&functionAnalysis]( |
606 | fir::DoLoopOp op) { |
607 | return !(functionAnalysis.getChildLoopAnalysis(op).canPromoteToAffine()); |
608 | }); |
609 | |
610 | LLVM_DEBUG(llvm::dbgs() |
611 | << "AffineDialectPromotion: running promotion on: \n" ; |
612 | function.print(llvm::dbgs());); |
613 | // apply the patterns |
614 | if (mlir::failed(mlir::applyPartialConversion(function, target, |
615 | std::move(patterns)))) { |
616 | mlir::emitError(mlir::UnknownLoc::get(context), |
617 | "error in converting to affine dialect\n" ); |
618 | signalPassFailure(); |
619 | } |
620 | } |
621 | }; |
622 | } // namespace |
623 | |
624 | /// Convert FIR loop constructs to the Affine dialect |
625 | std::unique_ptr<mlir::Pass> fir::createPromoteToAffinePass() { |
626 | return std::make_unique<AffineDialectPromotion>(); |
627 | } |
628 | |