//===- LowerWorkshare.cpp - Lower omp.workshare to other OpenMP ops ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements the lowering of omp.workshare to other omp constructs.
//
// This pass is tasked with parallelizing the loops nested in
// workshare.loop_wrapper, while the Fortran to MLIR lowering and the HLFIR to
// FIR lowering pipelines are responsible for emitting the
// workshare.loop_wrapper ops where appropriate, according to the
// `shouldUseWorkshareLowering` function.
//
//===----------------------------------------------------------------------===//

#include <flang/Optimizer/Builder/FIRBuilder.h>
#include <flang/Optimizer/Dialect/FIROps.h>
#include <flang/Optimizer/Dialect/FIRType.h>
#include <flang/Optimizer/HLFIR/HLFIROps.h>
#include <flang/Optimizer/OpenMP/Passes.h>
#include <llvm/ADT/BreadthFirstIterator.h>
#include <llvm/ADT/STLExtras.h>
#include <llvm/ADT/SmallVectorExtras.h>
#include <llvm/ADT/iterator_range.h>
#include <llvm/Support/ErrorHandling.h>
#include <mlir/Dialect/Arith/IR/Arith.h>
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
#include <mlir/Dialect/SCF/IR/SCF.h>
#include <mlir/IR/BuiltinOps.h>
#include <mlir/IR/IRMapping.h>
#include <mlir/IR/OpDefinition.h>
#include <mlir/IR/PatternMatch.h>
#include <mlir/IR/Value.h>
#include <mlir/IR/Visitors.h>
#include <mlir/Interfaces/SideEffectInterfaces.h>
#include <mlir/Support/LLVM.h>

#include <variant>

namespace flangomp {
#define GEN_PASS_DEF_LOWERWORKSHARE
#include "flang/Optimizer/OpenMP/Passes.h.inc"
} // namespace flangomp

#define DEBUG_TYPE "lower-workshare"

using namespace mlir;

namespace flangomp {

// Check for the nesting pattern below; we need to avoid sharing the work of
// statements which are nested in some constructs such as omp.critical or
// another omp.parallel.
//
// omp.workshare { // `wsOp`
//   ...
//     omp.T { // `parent`
//       ...
//         `op`
//
template <typename T>
static bool isNestedIn(omp::WorkshareOp wsOp, Operation *op) {
  T parent = op->getParentOfType<T>();
  if (!parent)
    return false;
  return wsOp->isProperAncestor(parent);
}

bool shouldUseWorkshareLowering(Operation *op) {
  auto parentWorkshare = op->getParentOfType<omp::WorkshareOp>();

  if (!parentWorkshare)
    return false;

  if (isNestedIn<omp::CriticalOp>(parentWorkshare, op))
    return false;

  // 2.8.3 workshare Construct
  // For a parallel construct, the construct is a unit of work with respect to
  // the workshare construct. The statements contained in the parallel
  // construct are executed by a new thread team.
  if (isNestedIn<omp::ParallelOp>(parentWorkshare, op))
    return false;

  // 2.8.2 single Construct
  // Binding: The binding thread set for a single region is the current team.
  // A single region binds to the innermost enclosing parallel region.
  // Description: Only one of the encountering threads will execute the
  // structured block associated with the single construct.
  if (isNestedIn<omp::SingleOp>(parentWorkshare, op))
    return false;

  // Do not use workshare lowering until we support CFG in omp.workshare.
  if (parentWorkshare.getRegion().getBlocks().size() != 1)
    return false;

  return true;
}

} // namespace flangomp

namespace {

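/// A contiguous range of operations [begin, end) within one block that
/// contains no operation which must be parallelized. Such a range is emitted
/// either inside an omp.single or, when every operation in it is safe to
/// parallelize, replicated directly on all threads.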
struct SingleRegion {
  Block::iterator begin, end;
};

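/// Returns true if `op` contains an omp.workshare.loop_wrapper that binds to
/// the workshare region currently being lowered, i.e. if the work it performs
/// must be distributed among the threads of the team.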
static bool mustParallelizeOp(Operation *op) {
  return op
      ->walk([&](Operation *nested) {
        // We need to be careful not to pick up workshare.loop_wrapper in
        // nested omp.parallel{omp.workshare} regions, i.e. make sure that
        // `nested` binds to the workshare region we are currently handling.
        //
        // For example:
        //
        // omp.parallel {
        //   omp.workshare { // currently handling this
        //     omp.parallel {
        //       omp.workshare { // nested workshare
        //         omp.workshare.loop_wrapper {}
        //
        // Therefore, we skip if we encounter a nested omp.workshare.
        if (isa<omp::WorkshareOp>(nested))
          return WalkResult::skip();
        if (isa<omp::WorkshareLoopWrapperOp>(nested))
          return WalkResult::interrupt();
        return WalkResult::advance();
      })
      .wasInterrupted();
}

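/// An operation is safe to parallelize (i.e. execute redundantly on every
/// thread of the team) if it has no memory effects, or if it is a declare
/// operation, which only attaches variable information to an existing memory
/// reference.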
static bool isSafeToParallelize(Operation *op) {
  return isa<hlfir::DeclareOp>(op) || isa<fir::DeclareOp>(op) ||
         isMemoryEffectFree(op);
}

/// Simple shallow copies suffice for our purposes in this pass, so we
/// implement this simpler alternative to the full-fledged `createCopyFunc` in
/// the frontend.
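///
/// For a scalar i32 copyprivate variable, the generated function is roughly of
/// the following form (illustrative sketch; the exact symbol name is derived
/// from the element type via `fir::getTypeAsString`):
///
///   func.func private @_workshare_copy_i32(%dst : !fir.ref<i32>,
///                                          %src : !fir.ref<i32>) {
///     %0 = fir.load %src : !fir.ref<i32>
///     fir.store %0 to %dst : !fir.ref<i32>
///     return
///   }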
static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
                                         fir::FirOpBuilder builder) {
  mlir::ModuleOp module = builder.getModule();
  auto rt = cast<fir::ReferenceType>(varType);
  mlir::Type eleTy = rt.getEleTy();
  std::string copyFuncName =
      fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");

  if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName))
    return decl;

  // Create the function.
  mlir::OpBuilder::InsertionGuard guard(builder);
  mlir::OpBuilder modBuilder(module.getBodyRegion());
  llvm::SmallVector<mlir::Type> argsTy = {varType, varType};
  auto funcType = mlir::FunctionType::get(builder.getContext(), argsTy, {});
  mlir::func::FuncOp funcOp =
      modBuilder.create<mlir::func::FuncOp>(loc, copyFuncName, funcType);
  funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
  fir::factory::setInternalLinkage(funcOp);
  builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
                      {loc, loc});
  builder.setInsertionPointToStart(&funcOp.getRegion().back());

  // Shallow copy: load from the source argument and store into the
  // destination argument.
  Value loaded = builder.create<fir::LoadOp>(loc, funcOp.getArgument(1));
  builder.create<fir::StoreOp>(loc, loaded, funcOp.getArgument(0));

  builder.create<mlir::func::ReturnOp>(loc);
  return funcOp;
}

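/// Returns true if `user` lies outside the single region `sr`. The user is
/// first walked up to its ancestor that is immediately nested in `parentOp`
/// (the operation containing `sr`'s block), so that uses inside nested regions
/// are attributed to the enclosing operation in that block.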
static bool isUserOutsideSR(Operation *user, Operation *parentOp,
                            SingleRegion sr) {
  while (user->getParentOp() != parentOp)
    user = user->getParentOp();
  return sr.begin->getBlock() != user->getBlock() ||
         !(user->isBeforeInBlock(&*sr.end) && sr.begin->isBeforeInBlock(user));
}

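/// Returns true if the value `v`, or any value computed from it through a
/// chain of safe-to-parallelize operations within `sr`, has a use outside the
/// single region `sr`.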
static bool isTransitivelyUsedOutside(Value v, SingleRegion sr) {
  Block *srBlock = sr.begin->getBlock();
  Operation *parentOp = srBlock->getParentOp();

  for (auto &use : v.getUses()) {
    Operation *user = use.getOwner();
    if (isUserOutsideSR(user, parentOp, sr))
      return true;

    // Now we know user is inside `sr`.

    // Results of nested users cannot be used outside of `sr`.
    if (user->getBlock() != srBlock)
      continue;

    // A non-safe-to-parallelize operation will be checked for uses outside
    // separately.
    if (!isSafeToParallelize(user))
      continue;

    // For safe-to-parallelize operations, we need to check if there is a
    // transitive use of `v` through them.
    for (auto res : user->getResults())
      if (isTransitivelyUsedOutside(res, sr))
        return true;
  }
  return false;
}

/// We clone pure operations in both the parallel and single blocks. This
/// function cleans them up if they end up with no uses.
static void cleanupBlock(Block *block) {
  for (Operation &op : llvm::make_early_inc_range(
           llvm::make_range(block->rbegin(), block->rend())))
    if (isOpTriviallyDead(&op))
      op.erase();
}

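/// Clones `sourceRegion` into `targetRegion`, rewriting its contents so that
/// they can be executed by a whole thread team: operations that must be
/// parallelized (e.g. omp.workshare.loop_wrapper) become omp.wsloop nests,
/// while runs of other operations are wrapped in omp.single regions whose
/// results are made visible to all threads either by replication (when safe)
/// or through copyprivate temporaries.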
static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
                              IRMapping &rootMapping, Location loc,
                              mlir::DominanceInfo &di) {
  OpBuilder rootBuilder(sourceRegion.getContext());
  ModuleOp m = sourceRegion.getParentOfType<ModuleOp>();
  OpBuilder copyFuncBuilder(m.getBodyRegion());
  fir::FirOpBuilder firCopyFuncBuilder(copyFuncBuilder, m);

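  // Allocate a temporary for `v`, store the value computed inside the
  // omp.single region into it, and reload it on all threads after the single.
  // Returns the temporary so it can be added to the copyprivate list, or
  // nullptr if `v` has already been remapped.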
  auto mapReloadedValue =
      [&](Value v, OpBuilder allocaBuilder, OpBuilder singleBuilder,
          OpBuilder parallelBuilder, IRMapping singleMapping) -> Value {
    if (auto reloaded = rootMapping.lookupOrNull(v))
      return nullptr;
    Type ty = v.getType();
    Value alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
    singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
    Value reloaded = parallelBuilder.create<fir::LoadOp>(loc, ty, alloc);
    rootMapping.map(v, reloaded);
    return alloc;
  };

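  // Emit the single region `sr`: clone its operations into the omp.single
  // block (`singleBuilder`), replicate the safe-to-parallelize ones into the
  // parallel block (`parallelBuilder`), and hoist allocas into `allocaBuilder`.
  // Returns whether every operation could be replicated on all threads (in
  // which case the omp.single is unnecessary) together with the list of
  // copyprivate temporaries.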
  auto moveToSingle =
      [&](SingleRegion sr, OpBuilder allocaBuilder, OpBuilder singleBuilder,
          OpBuilder parallelBuilder) -> std::pair<bool, SmallVector<Value>> {
    IRMapping singleMapping = rootMapping;
    SmallVector<Value> copyPrivate;
    bool allParallelized = true;

    for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
      if (isSafeToParallelize(&op)) {
        singleBuilder.clone(op, singleMapping);
        if (llvm::all_of(op.getOperands(), [&](Value opr) {
              // Either we have already remapped it
              bool remapped = rootMapping.contains(opr);
              // Or it is available because it dominates `sr`
              bool dominates = di.properlyDominates(opr, &*sr.begin);
              return remapped || dominates;
            })) {
          // Safe-to-parallelize operations which have all operands available
          // in the root parallel block can be executed there.
          parallelBuilder.clone(op, rootMapping);
        } else {
          // If any operand was not available, it means that there was no
          // transitive use of a non-safe-to-parallelize operation outside
          // `sr`. This means that there should be no transitive uses outside
          // `sr` of `op`.
          assert(llvm::all_of(op.getResults(), [&](Value v) {
            return !isTransitivelyUsedOutside(v, sr);
          }));
          allParallelized = false;
        }
      } else if (auto alloca = dyn_cast<fir::AllocaOp>(&op)) {
        auto hoisted =
            cast<fir::AllocaOp>(allocaBuilder.clone(*alloca, singleMapping));
        rootMapping.map(&*alloca, &*hoisted);
        rootMapping.map(alloca.getResult(), hoisted.getResult());
        copyPrivate.push_back(hoisted);
        allParallelized = false;
      } else {
        singleBuilder.clone(op, singleMapping);
        // Prepare reloaded values for results of operations that cannot be
        // safely parallelized and which are used after the region `sr`.
        for (auto res : op.getResults()) {
          if (isTransitivelyUsedOutside(res, sr)) {
            auto alloc = mapReloadedValue(res, allocaBuilder, singleBuilder,
                                          parallelBuilder, singleMapping);
            if (alloc)
              copyPrivate.push_back(alloc);
          }
        }
        allParallelized = false;
      }
    }
    singleBuilder.create<omp::TerminatorOp>(loc);
    return {allParallelized, copyPrivate};
  };

  for (Block &block : sourceRegion) {
    Block *targetBlock = rootBuilder.createBlock(
        &targetRegion, {}, block.getArgumentTypes(),
        llvm::map_to_vector(block.getArguments(),
                            [](BlockArgument arg) { return arg.getLoc(); }));
    rootMapping.map(&block, targetBlock);
    rootMapping.map(block.getArguments(), targetBlock->getArguments());
  }

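  // Lower one source block: partition it into an alternating sequence of
  // single regions and operations that must be parallelized, then emit an
  // omp.single (plus replicated/reloaded values) for each single region and an
  // omp.wsloop (or a recursively parallelized clone) for each parallelized op.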
  auto handleOneBlock = [&](Block &block) {
    Block &targetBlock = *rootMapping.lookup(&block);
    rootBuilder.setInsertionPointToStart(&targetBlock);
    Operation *terminator = block.getTerminator();
    SmallVector<std::variant<SingleRegion, Operation *>> regions;

    auto it = block.begin();
    auto getOneRegion = [&]() {
      if (&*it == terminator)
        return false;
      if (mustParallelizeOp(&*it)) {
        regions.push_back(&*it);
        it++;
        return true;
      }
      SingleRegion sr;
      sr.begin = it;
      while (&*it != terminator && !mustParallelizeOp(&*it))
        it++;
      sr.end = it;
      assert(sr.begin != sr.end);
      regions.push_back(sr);
      return true;
    };
    while (getOneRegion())
      ;

    for (auto [i, opOrSingle] : llvm::enumerate(regions)) {
      bool isLast = i + 1 == regions.size();
      if (std::holds_alternative<SingleRegion>(opOrSingle)) {
        OpBuilder singleBuilder(sourceRegion.getContext());
        Block *singleBlock = new Block();
        singleBuilder.setInsertionPointToStart(singleBlock);

        OpBuilder allocaBuilder(sourceRegion.getContext());
        Block *allocaBlock = new Block();
        allocaBuilder.setInsertionPointToStart(allocaBlock);

        OpBuilder parallelBuilder(sourceRegion.getContext());
        Block *parallelBlock = new Block();
        parallelBuilder.setInsertionPointToStart(parallelBlock);

        auto [allParallelized, copyprivateVars] =
            moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder,
                         singleBuilder, parallelBuilder);
        if (allParallelized) {
          // The single region was not required as all operations were safe to
          // parallelize.
          assert(copyprivateVars.empty());
          assert(allocaBlock->empty());
          delete singleBlock;
        } else {
          omp::SingleOperands singleOperands;
          if (isLast)
            singleOperands.nowait = rootBuilder.getUnitAttr();
          singleOperands.copyprivateVars = copyprivateVars;
          cleanupBlock(singleBlock);
          for (auto var : singleOperands.copyprivateVars) {
            mlir::func::FuncOp funcOp =
                createCopyFunc(loc, var.getType(), firCopyFuncBuilder);
            singleOperands.copyprivateSyms.push_back(
                SymbolRefAttr::get(funcOp));
          }
          omp::SingleOp singleOp =
              rootBuilder.create<omp::SingleOp>(loc, singleOperands);
          singleOp.getRegion().push_back(singleBlock);
          targetRegion.front().getOperations().splice(
              singleOp->getIterator(), allocaBlock->getOperations());
        }
        rootBuilder.getInsertionBlock()->getOperations().splice(
            rootBuilder.getInsertionPoint(), parallelBlock->getOperations());
        delete allocaBlock;
        delete parallelBlock;
      } else {
        auto op = std::get<Operation *>(opOrSingle);
        if (auto wslw = dyn_cast<omp::WorkshareLoopWrapperOp>(op)) {
          omp::WsloopOperands wsloopOperands;
          if (isLast)
            wsloopOperands.nowait = rootBuilder.getUnitAttr();
          auto wsloop =
              rootBuilder.create<mlir::omp::WsloopOp>(loc, wsloopOperands);
          auto clonedWslw = cast<omp::WorkshareLoopWrapperOp>(
              rootBuilder.clone(*wslw, rootMapping));
          wsloop.getRegion().takeBody(clonedWslw.getRegion());
          clonedWslw->erase();
        } else {
          assert(mustParallelizeOp(op));
          Operation *cloned = rootBuilder.cloneWithoutRegions(*op, rootMapping);
          for (auto [region, clonedRegion] :
               llvm::zip(op->getRegions(), cloned->getRegions()))
            parallelizeRegion(region, clonedRegion, rootMapping, loc, di);
        }
      }
    }

    rootBuilder.clone(*block.getTerminator(), rootMapping);
  };

  if (sourceRegion.hasOneBlock()) {
    handleOneBlock(sourceRegion.front());
  } else if (!sourceRegion.empty()) {
    auto &domTree = di.getDomTree(&sourceRegion);
    for (auto node : llvm::breadth_first(domTree.getRootNode())) {
      handleOneBlock(*node->getBlock());
    }
  }

  for (Block &targetBlock : targetRegion)
    cleanupBlock(&targetBlock);
}

/// Lowers workshare to a sequence of single-thread regions and parallel loops.
///
/// For example:
///
/// omp.workshare {
///   %a = fir.allocmem
///   omp.workshare.loop_wrapper {}
///   fir.call Assign %b %a
///   fir.freemem %a
/// }
///
/// becomes
///
/// %tmp = fir.alloca
/// omp.single copyprivate(%tmp) {
///   %a = fir.allocmem
///   fir.store %a %tmp
/// }
/// %a_reloaded = fir.load %tmp
/// omp.wsloop {}
/// omp.single {
///   fir.call Assign %b %a_reloaded
///   fir.freemem %a_reloaded
/// }
///
/// Note that we allocate temporary memory for values produced in omp.single
/// regions which need to be accessed by all threads, and broadcast them using
/// the single's copyprivate clause.
LogicalResult lowerWorkshare(mlir::omp::WorkshareOp wsOp, DominanceInfo &di) {
  Location loc = wsOp->getLoc();
  IRMapping rootMapping;

  OpBuilder rootBuilder(wsOp);

  // FIXME Currently, we only support workshare constructs with structured
  // control flow. The transformation itself supports CFG; however, once we
  // transform the MLIR region in the omp.workshare, we need to inline that
  // region into the parent block. We have no guarantees at this point of the
  // pipeline that the parent op supports CFG (e.g. fir.if), thus this is not
  // generally possible. The alternative is to put the lowered region in an
  // operation akin to scf.execute_region, which will get lowered at the same
  // time when fir ops get lowered to CFG. However, SCF is not registered in
  // flang so we cannot use it. Remove this requirement once we have
  // scf.execute_region or an alternative operation available.
  if (wsOp.getRegion().getBlocks().size() == 1) {
    // This operation is just a placeholder which will be erased later. We need
    // it because our `parallelizeRegion` function works on regions and not
    // blocks.
    omp::WorkshareOp newOp =
        rootBuilder.create<omp::WorkshareOp>(loc, omp::WorkshareOperands());
    if (!wsOp.getNowait())
      rootBuilder.create<omp::BarrierOp>(loc);

    parallelizeRegion(wsOp.getRegion(), newOp.getRegion(), rootMapping, loc,
                      di);

    // Inline the contents of the placeholder workshare op into its parent
    // block.
    Block *theBlock = &newOp.getRegion().front();
    Operation *term = theBlock->getTerminator();
    Block *parentBlock = wsOp->getBlock();
    parentBlock->getOperations().splice(newOp->getIterator(),
                                        theBlock->getOperations());
    assert(term->getNumOperands() == 0);
    term->erase();
    newOp->erase();
    wsOp->erase();
  } else {
    // Otherwise just change the operation to an omp.single.

    wsOp->emitWarning(
        "omp workshare with unstructured control flow is currently "
        "unsupported and will be serialized.");

    // `shouldUseWorkshareLowering` should have guaranteed that there are no
    // omp.workshare.loop_wrapper ops that bind to this omp.workshare.
    assert(!wsOp->walk([&](Operation *op) {
               // Nested omp.workshare ops can have their own
               // omp.workshare.loop_wrapper ops.
               if (isa<omp::WorkshareOp>(op))
                 return WalkResult::skip();
               if (isa<omp::WorkshareLoopWrapperOp>(op))
                 return WalkResult::interrupt();
               return WalkResult::advance();
             })
                .wasInterrupted());

    omp::SingleOperands operands;
    operands.nowait = wsOp.getNowaitAttr();
    omp::SingleOp newOp = rootBuilder.create<omp::SingleOp>(loc, operands);

    newOp.getRegion().getBlocks().splice(newOp.getRegion().getBlocks().begin(),
                                         wsOp.getRegion().getBlocks());
    wsOp->erase();
  }
  return success();
}

class LowerWorksharePass
    : public flangomp::impl::LowerWorkshareBase<LowerWorksharePass> {
public:
  void runOnOperation() override {
    mlir::DominanceInfo &di = getAnalysis<mlir::DominanceInfo>();
    getOperation()->walk([&](mlir::omp::WorkshareOp wsOp) {
      if (failed(lowerWorkshare(wsOp, di)))
        signalPassFailure();
    });
  }
};
} // namespace
