1//===-- VectorSubscripts.cpp -- Vector subscripts tools -------------------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/
10//
11//===----------------------------------------------------------------------===//
12
13#include "flang/Lower/VectorSubscripts.h"
14#include "flang/Lower/AbstractConverter.h"
15#include "flang/Lower/Support/Utils.h"
16#include "flang/Optimizer/Builder/Character.h"
17#include "flang/Optimizer/Builder/Complex.h"
18#include "flang/Optimizer/Builder/FIRBuilder.h"
19#include "flang/Optimizer/Builder/Todo.h"
20#include "flang/Semantics/expression.h"
21
22namespace {
23/// Helper class to lower a designator containing vector subscripts into a
24/// lowered representation that can be worked with.
25class VectorSubscriptBoxBuilder {
26public:
27 VectorSubscriptBoxBuilder(mlir::Location loc,
28 Fortran::lower::AbstractConverter &converter,
29 Fortran::lower::StatementContext &stmtCtx)
30 : converter{converter}, stmtCtx{stmtCtx}, loc{loc} {}
31
32 Fortran::lower::VectorSubscriptBox gen(const Fortran::lower::SomeExpr &expr) {
33 elementType = genDesignator(expr);
34 return Fortran::lower::VectorSubscriptBox(
35 std::move(loweredBase), std::move(loweredSubscripts),
36 std::move(componentPath), substringBounds, elementType);
37 }
38
39private:
40 using LoweredVectorSubscript =
41 Fortran::lower::VectorSubscriptBox::LoweredVectorSubscript;
42 using LoweredTriplet = Fortran::lower::VectorSubscriptBox::LoweredTriplet;
43 using LoweredSubscript = Fortran::lower::VectorSubscriptBox::LoweredSubscript;
44 using MaybeSubstring = Fortran::lower::VectorSubscriptBox::MaybeSubstring;
45
46 /// genDesignator unwraps a Designator<T> and calls `gen` on what the
47 /// designator actually contains.
48 template <typename A>
49 mlir::Type genDesignator(const A &) {
50 fir::emitFatalError(loc, "expr must contain a designator");
51 }
52 template <typename T>
53 mlir::Type genDesignator(const Fortran::evaluate::Expr<T> &expr) {
54 using ExprVariant = decltype(Fortran::evaluate::Expr<T>::u);
55 using Designator = Fortran::evaluate::Designator<T>;
56 if constexpr (Fortran::common::HasMember<Designator, ExprVariant>) {
57 const auto &designator = std::get<Designator>(expr.u);
58 return std::visit([&](const auto &x) { return gen(x); }, designator.u);
59 } else {
60 return std::visit([&](const auto &x) { return genDesignator(x); },
61 expr.u);
62 }
63 }
64
65 // The gen(X) methods visit X to lower its base and subscripts and return the
66 // type of X elements.
67
68 mlir::Type gen(const Fortran::evaluate::DataRef &dataRef) {
69 return std::visit([&](const auto &ref) -> mlir::Type { return gen(ref); },
70 dataRef.u);
71 }
72
73 mlir::Type gen(const Fortran::evaluate::SymbolRef &symRef) {
74 // Never visited because expr lowering is used to lowered the ranked
75 // ArrayRef.
76 fir::emitFatalError(
77 loc, "expected at least one ArrayRef with vector susbcripts");
78 }
79
80 mlir::Type gen(const Fortran::evaluate::Substring &substring) {
81 // StaticDataObject::Pointer bases are constants and cannot be
82 // subscripted, so the base must be a DataRef here.
83 mlir::Type baseElementType =
84 gen(std::get<Fortran::evaluate::DataRef>(substring.parent()));
85 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
86 mlir::Type idxTy = builder.getIndexType();
87 mlir::Value lb = genScalarValue(substring.lower());
88 substringBounds.emplace_back(builder.createConvert(loc, idxTy, lb));
89 if (const auto &ubExpr = substring.upper()) {
90 mlir::Value ub = genScalarValue(*ubExpr);
91 substringBounds.emplace_back(builder.createConvert(loc, idxTy, ub));
92 }
93 return baseElementType;
94 }
95
96 mlir::Type gen(const Fortran::evaluate::ComplexPart &complexPart) {
97 auto complexType = gen(complexPart.complex());
98 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
99 mlir::Type i32Ty = builder.getI32Type(); // llvm's GEP requires i32
100 mlir::Value offset = builder.createIntegerConstant(
101 loc, i32Ty,
102 complexPart.part() == Fortran::evaluate::ComplexPart::Part::RE ? 0 : 1);
103 componentPath.emplace_back(offset);
104 return fir::factory::Complex{builder, loc}.getComplexPartType(complexType);
105 }
106
107 mlir::Type gen(const Fortran::evaluate::Component &component) {
108 auto recTy = gen(component.base()).cast<fir::RecordType>();
109 const Fortran::semantics::Symbol &componentSymbol =
110 component.GetLastSymbol();
111 // Parent components will not be found here, they are not part
112 // of the FIR type and cannot be used in the path yet.
113 if (componentSymbol.test(Fortran::semantics::Symbol::Flag::ParentComp))
114 TODO(loc, "reference to parent component");
115 mlir::Type fldTy = fir::FieldType::get(&converter.getMLIRContext());
116 llvm::StringRef componentName = toStringRef(componentSymbol.name());
117 // Parameters threading in field_index is not yet very clear. We only
118 // have the ones of the ranked array ref at hand, but it looks like
119 // the fir.field_index expects the one of the direct base.
120 if (recTy.getNumLenParams() != 0)
121 TODO(loc, "threading length parameters in field index op");
122 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
123 componentPath.emplace_back(builder.create<fir::FieldIndexOp>(
124 loc, fldTy, componentName, recTy, /*typeParams*/ std::nullopt));
125 return fir::unwrapSequenceType(recTy.getType(componentName));
126 }
127
128 mlir::Type gen(const Fortran::evaluate::ArrayRef &arrayRef) {
129 auto isTripletOrVector =
130 [](const Fortran::evaluate::Subscript &subscript) -> bool {
131 return std::visit(
132 Fortran::common::visitors{
133 [](const Fortran::evaluate::IndirectSubscriptIntegerExpr &expr) {
134 return expr.value().Rank() != 0;
135 },
136 [&](const Fortran::evaluate::Triplet &) { return true; }},
137 subscript.u);
138 };
139 if (llvm::any_of(arrayRef.subscript(), isTripletOrVector))
140 return genRankedArrayRefSubscriptAndBase(arrayRef);
141
142 // This is a scalar ArrayRef (only scalar indexes), collect the indexes and
143 // visit the base that must contain another arrayRef with the vector
144 // subscript.
145 mlir::Type elementType = gen(namedEntityToDataRef(arrayRef.base()));
146 for (const Fortran::evaluate::Subscript &subscript : arrayRef.subscript()) {
147 const auto &expr =
148 std::get<Fortran::evaluate::IndirectSubscriptIntegerExpr>(
149 subscript.u);
150 componentPath.emplace_back(genScalarValue(expr.value()));
151 }
152 return elementType;
153 }
154
155 /// Lower the subscripts and base of the ArrayRef that is an array (there must
156 /// be one since there is a vector subscript, and there can only be one
157 /// according to C925).
158 mlir::Type genRankedArrayRefSubscriptAndBase(
159 const Fortran::evaluate::ArrayRef &arrayRef) {
160 // Lower the save the base
161 Fortran::lower::SomeExpr baseExpr = namedEntityToExpr(arrayRef.base());
162 loweredBase = converter.genExprAddr(baseExpr, stmtCtx);
163 // Lower and save the subscripts
164 fir::FirOpBuilder &builder = converter.getFirOpBuilder();
165 mlir::Type idxTy = builder.getIndexType();
166 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
167 for (const auto &subscript : llvm::enumerate(arrayRef.subscript())) {
168 std::visit(
169 Fortran::common::visitors{
170 [&](const Fortran::evaluate::IndirectSubscriptIntegerExpr &expr) {
171 if (expr.value().Rank() == 0) {
172 // Simple scalar subscript
173 loweredSubscripts.emplace_back(genScalarValue(expr.value()));
174 } else {
175 // Vector subscript.
176 // Remove conversion if any to avoid temp creation that may
177 // have been added by the front-end to avoid the creation of a
178 // temp array value.
179 auto vector = converter.genExprAddr(
180 ignoreEvConvert(expr.value()), stmtCtx);
181 mlir::Value size =
182 fir::factory::readExtent(builder, loc, vector, /*dim=*/0);
183 size = builder.createConvert(loc, idxTy, size);
184 loweredSubscripts.emplace_back(
185 LoweredVectorSubscript{std::move(vector), size});
186 }
187 },
188 [&](const Fortran::evaluate::Triplet &triplet) {
189 mlir::Value lb, ub;
190 if (const auto &lbExpr = triplet.lower())
191 lb = genScalarValue(*lbExpr);
192 else
193 lb = fir::factory::readLowerBound(builder, loc, loweredBase,
194 subscript.index(), one);
195 if (const auto &ubExpr = triplet.upper())
196 ub = genScalarValue(*ubExpr);
197 else
198 ub = fir::factory::readExtent(builder, loc, loweredBase,
199 subscript.index());
200 lb = builder.createConvert(loc, idxTy, lb);
201 ub = builder.createConvert(loc, idxTy, ub);
202 mlir::Value stride = genScalarValue(triplet.stride());
203 stride = builder.createConvert(loc, idxTy, stride);
204 loweredSubscripts.emplace_back(LoweredTriplet{lb, ub, stride});
205 },
206 },
207 subscript.value().u);
208 }
209 return fir::unwrapSequenceType(
210 fir::unwrapPassByRefType(fir::getBase(loweredBase).getType()));
211 }
212
213 mlir::Type gen(const Fortran::evaluate::CoarrayRef &) {
214 // Is this possible/legal ?
215 TODO(loc, "coarray: reference to coarray object with vector subscript in "
216 "IO input");
217 }
218
219 template <typename A>
220 mlir::Value genScalarValue(const A &expr) {
221 return fir::getBase(converter.genExprValue(toEvExpr(expr), stmtCtx));
222 }
223
224 Fortran::evaluate::DataRef
225 namedEntityToDataRef(const Fortran::evaluate::NamedEntity &namedEntity) {
226 if (namedEntity.IsSymbol())
227 return Fortran::evaluate::DataRef{namedEntity.GetFirstSymbol()};
228 return Fortran::evaluate::DataRef{namedEntity.GetComponent()};
229 }
230
231 Fortran::lower::SomeExpr
232 namedEntityToExpr(const Fortran::evaluate::NamedEntity &namedEntity) {
233 return Fortran::evaluate::AsGenericExpr(namedEntityToDataRef(namedEntity))
234 .value();
235 }
236
237 Fortran::lower::AbstractConverter &converter;
238 Fortran::lower::StatementContext &stmtCtx;
239 mlir::Location loc;
240 /// Elements of VectorSubscriptBox being built.
241 fir::ExtendedValue loweredBase;
242 llvm::SmallVector<LoweredSubscript, 16> loweredSubscripts;
243 llvm::SmallVector<mlir::Value> componentPath;
244 MaybeSubstring substringBounds;
245 mlir::Type elementType;
246};
247} // namespace
248
249Fortran::lower::VectorSubscriptBox Fortran::lower::genVectorSubscriptBox(
250 mlir::Location loc, Fortran::lower::AbstractConverter &converter,
251 Fortran::lower::StatementContext &stmtCtx,
252 const Fortran::lower::SomeExpr &expr) {
253 return VectorSubscriptBoxBuilder(loc, converter, stmtCtx).gen(expr);
254}
255
/// Shared implementation of loopOverElements and loopOverElementsWhile.
///
/// Emits one loop per triplet or vector subscript (scalar subscripts get no
/// loop, see genLoopBounds). The first bounds returned by genLoopBounds belong
/// to the last such dimension and produce the outermost loop, which gives a
/// column-major traversal. In the innermost loop body, the element designated
/// by the current induction variables is computed with getElementAt and passed
/// to \p elementalGenerator. On return, the builder's insertion point is right
/// after the outermost loop.
///
/// \tparam LoopType fir::DoLoopOp or fir::IterWhileOp.
/// \param initialCondition Only used with fir::IterWhileOp: the initial
///        iterate-while condition of the outermost loop.
/// \return With fir::IterWhileOp, result 0 of the outermost loop (the final
///         condition); a null mlir::Value with fir::DoLoopOp.
template <typename LoopType, typename Generator>
mlir::Value Fortran::lower::VectorSubscriptBox::loopOverElementsBase(
    fir::FirOpBuilder &builder, mlir::Location loc,
    const Generator &elementalGenerator,
    [[maybe_unused]] mlir::Value initialCondition) {
  mlir::Value shape = builder.createShape(loc, loweredBase);
  mlir::Value slice = createSlice(builder, loc);

  // Create loop nest for triplets and vector subscripts in column
  // major order.
  llvm::SmallVector<mlir::Value> inductionVariables;
  LoopType outerLoop;
  for (auto [lb, ub, step] : genLoopBounds(builder, loc)) {
    LoopType loop;
    if constexpr (std::is_same_v<LoopType, fir::IterWhileOp>) {
      // Each nested loop starts from the iterate variable of its parent.
      loop =
          builder.create<fir::IterWhileOp>(loc, lb, ub, step, initialCondition);
      initialCondition = loop.getIterateVar();
      if (!outerLoop)
        outerLoop = loop;
      else
        // The insertion point is still in the parent loop body, right after
        // the new loop: make the parent yield the new loop's result.
        builder.create<fir::ResultOp>(loc, loop.getResult(0));
    } else {
      loop =
          builder.create<fir::DoLoopOp>(loc, lb, ub, step, /*unordered=*/false);
      if (!outerLoop)
        outerLoop = loop;
    }
    // The next loop (or the element code) goes at the start of this body.
    builder.setInsertionPointToStart(loop.getBody());
    inductionVariables.push_back(loop.getInductionVar());
  }
  assert(outerLoop && !inductionVariables.empty() &&
         "at least one loop should be created");

  fir::ExtendedValue elem =
      getElementAt(builder, loc, shape, slice, inductionVariables);

  if constexpr (std::is_same_v<LoopType, fir::IterWhileOp>) {
    // The innermost loop yields the value returned by the generator.
    auto res = elementalGenerator(elem);
    builder.create<fir::ResultOp>(loc, res);
    builder.setInsertionPointAfter(outerLoop);
    return outerLoop.getResult(0);
  } else {
    elementalGenerator(elem);
    builder.setInsertionPointAfter(outerLoop);
    return {};
  }
}
304
305void Fortran::lower::VectorSubscriptBox::loopOverElements(
306 fir::FirOpBuilder &builder, mlir::Location loc,
307 const ElementalGenerator &elementalGenerator) {
308 mlir::Value initialCondition;
309 loopOverElementsBase<fir::DoLoopOp, ElementalGenerator>(
310 builder, loc, elementalGenerator, initialCondition);
311}
312
313mlir::Value Fortran::lower::VectorSubscriptBox::loopOverElementsWhile(
314 fir::FirOpBuilder &builder, mlir::Location loc,
315 const ElementalGeneratorWithBoolReturn &elementalGenerator,
316 mlir::Value initialCondition) {
317 return loopOverElementsBase<fir::IterWhileOp,
318 ElementalGeneratorWithBoolReturn>(
319 builder, loc, elementalGenerator, initialCondition);
320}
321
322mlir::Value
323Fortran::lower::VectorSubscriptBox::createSlice(fir::FirOpBuilder &builder,
324 mlir::Location loc) {
325 mlir::Type idxTy = builder.getIndexType();
326 llvm::SmallVector<mlir::Value> triples;
327 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
328 auto undef = builder.create<fir::UndefOp>(loc, idxTy);
329 for (const LoweredSubscript &subscript : loweredSubscripts)
330 std::visit(Fortran::common::visitors{
331 [&](const LoweredTriplet &triplet) {
332 triples.emplace_back(triplet.lb);
333 triples.emplace_back(triplet.ub);
334 triples.emplace_back(triplet.stride);
335 },
336 [&](const LoweredVectorSubscript &vector) {
337 triples.emplace_back(one);
338 triples.emplace_back(vector.size);
339 triples.emplace_back(one);
340 },
341 [&](const mlir::Value &i) {
342 triples.emplace_back(i);
343 triples.emplace_back(undef);
344 triples.emplace_back(undef);
345 },
346 },
347 subscript);
348 return builder.create<fir::SliceOp>(loc, triples, componentPath);
349}
350
351llvm::SmallVector<std::tuple<mlir::Value, mlir::Value, mlir::Value>>
352Fortran::lower::VectorSubscriptBox::genLoopBounds(fir::FirOpBuilder &builder,
353 mlir::Location loc) {
354 mlir::Type idxTy = builder.getIndexType();
355 mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1);
356 mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0);
357 llvm::SmallVector<std::tuple<mlir::Value, mlir::Value, mlir::Value>> bounds;
358 size_t dimension = loweredSubscripts.size();
359 for (const LoweredSubscript &subscript : llvm::reverse(loweredSubscripts)) {
360 --dimension;
361 if (std::holds_alternative<mlir::Value>(subscript))
362 continue;
363 mlir::Value lb, ub, step;
364 if (const auto *triplet = std::get_if<LoweredTriplet>(&subscript)) {
365 mlir::Value extent = builder.genExtentFromTriplet(
366 loc, triplet->lb, triplet->ub, triplet->stride, idxTy);
367 mlir::Value baseLb = fir::factory::readLowerBound(
368 builder, loc, loweredBase, dimension, one);
369 baseLb = builder.createConvert(loc, idxTy, baseLb);
370 lb = baseLb;
371 ub = builder.create<mlir::arith::SubIOp>(loc, idxTy, extent, one);
372 ub = builder.create<mlir::arith::AddIOp>(loc, idxTy, ub, baseLb);
373 step = one;
374 } else {
375 const auto &vector = std::get<LoweredVectorSubscript>(subscript);
376 lb = zero;
377 ub = builder.create<mlir::arith::SubIOp>(loc, idxTy, vector.size, one);
378 step = one;
379 }
380 bounds.emplace_back(lb, ub, step);
381 }
382 return bounds;
383}
384
385fir::ExtendedValue Fortran::lower::VectorSubscriptBox::getElementAt(
386 fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value shape,
387 mlir::Value slice, mlir::ValueRange inductionVariables) {
388 /// Generate the indexes for the array_coor inside the loops.
389 mlir::Type idxTy = builder.getIndexType();
390 llvm::SmallVector<mlir::Value> indexes;
391 size_t inductionIdx = inductionVariables.size() - 1;
392 for (const LoweredSubscript &subscript : loweredSubscripts)
393 std::visit(Fortran::common::visitors{
394 [&](const LoweredTriplet &triplet) {
395 indexes.emplace_back(inductionVariables[inductionIdx--]);
396 },
397 [&](const LoweredVectorSubscript &vector) {
398 mlir::Value vecIndex = inductionVariables[inductionIdx--];
399 mlir::Value vecBase = fir::getBase(vector.vector);
400 mlir::Type vecEleTy = fir::unwrapSequenceType(
401 fir::unwrapPassByRefType(vecBase.getType()));
402 mlir::Type refTy = builder.getRefType(vecEleTy);
403 auto vecEltRef = builder.create<fir::CoordinateOp>(
404 loc, refTy, vecBase, vecIndex);
405 auto vecElt =
406 builder.create<fir::LoadOp>(loc, vecEleTy, vecEltRef);
407 indexes.emplace_back(
408 builder.createConvert(loc, idxTy, vecElt));
409 },
410 [&](const mlir::Value &i) {
411 indexes.emplace_back(builder.createConvert(loc, idxTy, i));
412 },
413 },
414 subscript);
415 mlir::Type refTy = builder.getRefType(getElementType());
416 auto elementAddr = builder.create<fir::ArrayCoorOp>(
417 loc, refTy, fir::getBase(loweredBase), shape, slice, indexes,
418 fir::getTypeParams(loweredBase));
419 fir::ExtendedValue element = fir::factory::arraySectionElementToExtendedValue(
420 builder, loc, loweredBase, elementAddr, slice);
421 if (!substringBounds.empty()) {
422 const fir::CharBoxValue *charBox = element.getCharBox();
423 assert(charBox && "substring requires CharBox base");
424 fir::factory::CharacterExprHelper helper{builder, loc};
425 return helper.createSubstring(*charBox, substringBounds);
426 }
427 return element;
428}
429

source code of flang/lib/Lower/VectorSubscripts.cpp