1 | //===-- VectorSubscripts.cpp -- Vector subscripts tools -------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "flang/Lower/VectorSubscripts.h" |
14 | #include "flang/Lower/AbstractConverter.h" |
15 | #include "flang/Lower/Support/Utils.h" |
16 | #include "flang/Optimizer/Builder/Character.h" |
17 | #include "flang/Optimizer/Builder/Complex.h" |
18 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
19 | #include "flang/Optimizer/Builder/Todo.h" |
20 | #include "flang/Semantics/expression.h" |
21 | |
22 | namespace { |
23 | /// Helper class to lower a designator containing vector subscripts into a |
24 | /// lowered representation that can be worked with. |
25 | class VectorSubscriptBoxBuilder { |
26 | public: |
27 | VectorSubscriptBoxBuilder(mlir::Location loc, |
28 | Fortran::lower::AbstractConverter &converter, |
29 | Fortran::lower::StatementContext &stmtCtx) |
30 | : converter{converter}, stmtCtx{stmtCtx}, loc{loc} {} |
31 | |
32 | Fortran::lower::VectorSubscriptBox gen(const Fortran::lower::SomeExpr &expr) { |
33 | elementType = genDesignator(expr); |
34 | return Fortran::lower::VectorSubscriptBox( |
35 | std::move(loweredBase), std::move(loweredSubscripts), |
36 | std::move(componentPath), substringBounds, elementType); |
37 | } |
38 | |
39 | private: |
40 | using LoweredVectorSubscript = |
41 | Fortran::lower::VectorSubscriptBox::LoweredVectorSubscript; |
42 | using LoweredTriplet = Fortran::lower::VectorSubscriptBox::LoweredTriplet; |
43 | using LoweredSubscript = Fortran::lower::VectorSubscriptBox::LoweredSubscript; |
44 | using MaybeSubstring = Fortran::lower::VectorSubscriptBox::MaybeSubstring; |
45 | |
46 | /// genDesignator unwraps a Designator<T> and calls `gen` on what the |
47 | /// designator actually contains. |
48 | template <typename A> |
49 | mlir::Type genDesignator(const A &) { |
50 | fir::emitFatalError(loc, "expr must contain a designator" ); |
51 | } |
52 | template <typename T> |
53 | mlir::Type genDesignator(const Fortran::evaluate::Expr<T> &expr) { |
54 | using ExprVariant = decltype(Fortran::evaluate::Expr<T>::u); |
55 | using Designator = Fortran::evaluate::Designator<T>; |
56 | if constexpr (Fortran::common::HasMember<Designator, ExprVariant>) { |
57 | const auto &designator = std::get<Designator>(expr.u); |
58 | return Fortran::common::visit([&](const auto &x) { return gen(x); }, |
59 | designator.u); |
60 | } else { |
61 | return Fortran::common::visit( |
62 | [&](const auto &x) { return genDesignator(x); }, expr.u); |
63 | } |
64 | } |
65 | |
66 | // The gen(X) methods visit X to lower its base and subscripts and return the |
67 | // type of X elements. |
68 | |
69 | mlir::Type gen(const Fortran::evaluate::DataRef &dataRef) { |
70 | return Fortran::common::visit( |
71 | [&](const auto &ref) -> mlir::Type { return gen(ref); }, dataRef.u); |
72 | } |
73 | |
74 | mlir::Type gen(const Fortran::evaluate::SymbolRef &symRef) { |
75 | // Never visited because expr lowering is used to lowered the ranked |
76 | // ArrayRef. |
77 | fir::emitFatalError( |
78 | loc, "expected at least one ArrayRef with vector susbcripts" ); |
79 | } |
80 | |
81 | mlir::Type gen(const Fortran::evaluate::Substring &substring) { |
82 | // StaticDataObject::Pointer bases are constants and cannot be |
83 | // subscripted, so the base must be a DataRef here. |
84 | mlir::Type baseElementType = |
85 | gen(std::get<Fortran::evaluate::DataRef>(substring.parent())); |
86 | fir::FirOpBuilder &builder = converter.getFirOpBuilder(); |
87 | mlir::Type idxTy = builder.getIndexType(); |
88 | mlir::Value lb = genScalarValue(substring.lower()); |
89 | substringBounds.emplace_back(builder.createConvert(loc, idxTy, lb)); |
90 | if (const auto &ubExpr = substring.upper()) { |
91 | mlir::Value ub = genScalarValue(*ubExpr); |
92 | substringBounds.emplace_back(builder.createConvert(loc, idxTy, ub)); |
93 | } |
94 | return baseElementType; |
95 | } |
96 | |
97 | mlir::Type gen(const Fortran::evaluate::ComplexPart &complexPart) { |
98 | auto complexType = gen(complexPart.complex()); |
99 | fir::FirOpBuilder &builder = converter.getFirOpBuilder(); |
100 | mlir::Type i32Ty = builder.getI32Type(); // llvm's GEP requires i32 |
101 | mlir::Value offset = builder.createIntegerConstant( |
102 | loc, i32Ty, |
103 | complexPart.part() == Fortran::evaluate::ComplexPart::Part::RE ? 0 : 1); |
104 | componentPath.emplace_back(offset); |
105 | return fir::factory::Complex{builder, loc}.getComplexPartType(complexType); |
106 | } |
107 | |
108 | mlir::Type gen(const Fortran::evaluate::Component &component) { |
109 | auto recTy = mlir::cast<fir::RecordType>(gen(component.base())); |
110 | const Fortran::semantics::Symbol &componentSymbol = |
111 | component.GetLastSymbol(); |
112 | // Parent components will not be found here, they are not part |
113 | // of the FIR type and cannot be used in the path yet. |
114 | if (componentSymbol.test(Fortran::semantics::Symbol::Flag::ParentComp)) |
115 | TODO(loc, "reference to parent component" ); |
116 | mlir::Type fldTy = fir::FieldType::get(&converter.getMLIRContext()); |
117 | llvm::StringRef componentName = toStringRef(componentSymbol.name()); |
118 | // Parameters threading in field_index is not yet very clear. We only |
119 | // have the ones of the ranked array ref at hand, but it looks like |
120 | // the fir.field_index expects the one of the direct base. |
121 | if (recTy.getNumLenParams() != 0) |
122 | TODO(loc, "threading length parameters in field index op" ); |
123 | fir::FirOpBuilder &builder = converter.getFirOpBuilder(); |
124 | componentPath.emplace_back(builder.create<fir::FieldIndexOp>( |
125 | loc, fldTy, componentName, recTy, /*typeParams*/ std::nullopt)); |
126 | return fir::unwrapSequenceType(recTy.getType(componentName)); |
127 | } |
128 | |
129 | mlir::Type gen(const Fortran::evaluate::ArrayRef &arrayRef) { |
130 | auto isTripletOrVector = |
131 | [](const Fortran::evaluate::Subscript &subscript) -> bool { |
132 | return Fortran::common::visit( |
133 | Fortran::common::visitors{ |
134 | [](const Fortran::evaluate::IndirectSubscriptIntegerExpr &expr) { |
135 | return expr.value().Rank() != 0; |
136 | }, |
137 | [&](const Fortran::evaluate::Triplet &) { return true; }}, |
138 | subscript.u); |
139 | }; |
140 | if (llvm::any_of(arrayRef.subscript(), isTripletOrVector)) |
141 | return genRankedArrayRefSubscriptAndBase(arrayRef); |
142 | |
143 | // This is a scalar ArrayRef (only scalar indexes), collect the indexes and |
144 | // visit the base that must contain another arrayRef with the vector |
145 | // subscript. |
146 | mlir::Type elementType = gen(namedEntityToDataRef(arrayRef.base())); |
147 | for (const Fortran::evaluate::Subscript &subscript : arrayRef.subscript()) { |
148 | const auto &expr = |
149 | std::get<Fortran::evaluate::IndirectSubscriptIntegerExpr>( |
150 | subscript.u); |
151 | componentPath.emplace_back(genScalarValue(expr.value())); |
152 | } |
153 | return elementType; |
154 | } |
155 | |
156 | /// Lower the subscripts and base of the ArrayRef that is an array (there must |
157 | /// be one since there is a vector subscript, and there can only be one |
158 | /// according to C925). |
159 | mlir::Type genRankedArrayRefSubscriptAndBase( |
160 | const Fortran::evaluate::ArrayRef &arrayRef) { |
161 | // Lower the save the base |
162 | Fortran::lower::SomeExpr baseExpr = namedEntityToExpr(arrayRef.base()); |
163 | loweredBase = converter.genExprAddr(baseExpr, stmtCtx); |
164 | // Lower and save the subscripts |
165 | fir::FirOpBuilder &builder = converter.getFirOpBuilder(); |
166 | mlir::Type idxTy = builder.getIndexType(); |
167 | mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); |
168 | for (const auto &subscript : llvm::enumerate(arrayRef.subscript())) { |
169 | Fortran::common::visit( |
170 | Fortran::common::visitors{ |
171 | [&](const Fortran::evaluate::IndirectSubscriptIntegerExpr &expr) { |
172 | if (expr.value().Rank() == 0) { |
173 | // Simple scalar subscript |
174 | loweredSubscripts.emplace_back(genScalarValue(expr.value())); |
175 | } else { |
176 | // Vector subscript. |
177 | // Remove conversion if any to avoid temp creation that may |
178 | // have been added by the front-end to avoid the creation of a |
179 | // temp array value. |
180 | auto vector = converter.genExprAddr( |
181 | ignoreEvConvert(expr.value()), stmtCtx); |
182 | mlir::Value size = |
183 | fir::factory::readExtent(builder, loc, vector, /*dim=*/0); |
184 | size = builder.createConvert(loc, idxTy, size); |
185 | loweredSubscripts.emplace_back( |
186 | LoweredVectorSubscript{std::move(vector), size}); |
187 | } |
188 | }, |
189 | [&](const Fortran::evaluate::Triplet &triplet) { |
190 | mlir::Value lb, ub; |
191 | if (const auto &lbExpr = triplet.lower()) |
192 | lb = genScalarValue(*lbExpr); |
193 | else |
194 | lb = fir::factory::readLowerBound(builder, loc, loweredBase, |
195 | subscript.index(), one); |
196 | if (const auto &ubExpr = triplet.upper()) |
197 | ub = genScalarValue(*ubExpr); |
198 | else |
199 | ub = fir::factory::readExtent(builder, loc, loweredBase, |
200 | subscript.index()); |
201 | lb = builder.createConvert(loc, idxTy, lb); |
202 | ub = builder.createConvert(loc, idxTy, ub); |
203 | mlir::Value stride = genScalarValue(triplet.stride()); |
204 | stride = builder.createConvert(loc, idxTy, stride); |
205 | loweredSubscripts.emplace_back(LoweredTriplet{lb, ub, stride}); |
206 | }, |
207 | }, |
208 | subscript.value().u); |
209 | } |
210 | return fir::unwrapSequenceType( |
211 | fir::unwrapPassByRefType(fir::getBase(loweredBase).getType())); |
212 | } |
213 | |
214 | mlir::Type gen(const Fortran::evaluate::CoarrayRef &) { |
215 | // Is this possible/legal ? |
216 | TODO(loc, "coarray: reference to coarray object with vector subscript in " |
217 | "IO input" ); |
218 | } |
219 | |
220 | template <typename A> |
221 | mlir::Value genScalarValue(const A &expr) { |
222 | return fir::getBase(converter.genExprValue(toEvExpr(expr), stmtCtx)); |
223 | } |
224 | |
225 | Fortran::evaluate::DataRef |
226 | namedEntityToDataRef(const Fortran::evaluate::NamedEntity &namedEntity) { |
227 | if (namedEntity.IsSymbol()) |
228 | return Fortran::evaluate::DataRef{namedEntity.GetFirstSymbol()}; |
229 | return Fortran::evaluate::DataRef{namedEntity.GetComponent()}; |
230 | } |
231 | |
232 | Fortran::lower::SomeExpr |
233 | namedEntityToExpr(const Fortran::evaluate::NamedEntity &namedEntity) { |
234 | return Fortran::evaluate::AsGenericExpr(namedEntityToDataRef(namedEntity)) |
235 | .value(); |
236 | } |
237 | |
238 | Fortran::lower::AbstractConverter &converter; |
239 | Fortran::lower::StatementContext &stmtCtx; |
240 | mlir::Location loc; |
241 | /// Elements of VectorSubscriptBox being built. |
242 | fir::ExtendedValue loweredBase; |
243 | llvm::SmallVector<LoweredSubscript, 16> loweredSubscripts; |
244 | llvm::SmallVector<mlir::Value> componentPath; |
245 | MaybeSubstring substringBounds; |
246 | mlir::Type elementType; |
247 | }; |
248 | } // namespace |
249 | |
250 | Fortran::lower::VectorSubscriptBox Fortran::lower::genVectorSubscriptBox( |
251 | mlir::Location loc, Fortran::lower::AbstractConverter &converter, |
252 | Fortran::lower::StatementContext &stmtCtx, |
253 | const Fortran::lower::SomeExpr &expr) { |
254 | return VectorSubscriptBoxBuilder(loc, converter, stmtCtx).gen(expr); |
255 | } |
256 | |
/// Shared implementation of loopOverElements and loopOverElementsWhile.
///
/// Creates the shape of the lowered base and the slice for the subscripts.
/// It then builds a nest of \p LoopType loops, one per (lb, ub, step) tuple
/// returned by genLoopBounds. Scalar subscripts get no loop there.
/// In the innermost body it materializes the current element with
/// getElementAt and hands it to \p elementalGenerator. Finally it moves the
/// builder insertion point after the outermost loop.
///
/// For fir::IterWhileOp:
///  - \p initialCondition seeds the outermost loop;
///  - every nested loop is seeded with its parent's iterate variable;
///  - each parent yields its child's first result, and the innermost loop
///    yields the generator's return value;
///  - the outermost loop's first result is returned.
/// For fir::DoLoopOp, \p initialCondition is not used and a null value is
/// returned.
template <typename LoopType, typename Generator>
mlir::Value Fortran::lower::VectorSubscriptBox::loopOverElementsBase(
    fir::FirOpBuilder &builder, mlir::Location loc,
    const Generator &elementalGenerator,
    [[maybe_unused]] mlir::Value initialCondition) {
  mlir::Value shape = builder.createShape(loc, loweredBase);
  mlir::Value slice = createSlice(builder, loc);

  // Create loop nest for triplets and vector subscripts in column
  // major order.
  llvm::SmallVector<mlir::Value> inductionVariables;
  LoopType outerLoop;
  for (auto [lb, ub, step] : genLoopBounds(builder, loc)) {
    LoopType loop;
    if constexpr (std::is_same_v<LoopType, fir::IterWhileOp>) {
      loop =
          builder.create<fir::IterWhileOp>(loc, lb, ub, step, initialCondition);
      // The next (nested) loop starts from this loop's iterate variable.
      initialCondition = loop.getIterateVar();
      // The first loop becomes the outermost loop. Each later loop is created
      // inside the previous loop's body, which must then yield the new loop's
      // result.
      if (!outerLoop)
        outerLoop = loop;
      else
        builder.create<fir::ResultOp>(loc, loop.getResult(0));
    } else {
      loop =
          builder.create<fir::DoLoopOp>(loc, lb, ub, step, /*unordered=*/false);
      if (!outerLoop)
        outerLoop = loop;
    }
    // The next loop (or the element code) goes at the start of this body.
    builder.setInsertionPointToStart(loop.getBody());
    inductionVariables.push_back(loop.getInductionVar());
  }
  assert(outerLoop && !inductionVariables.empty() &&
         "at least one loop should be created");

  // Innermost body: compute the current element from all induction variables.
  fir::ExtendedValue elem =
      getElementAt(builder, loc, shape, slice, inductionVariables);

  if constexpr (std::is_same_v<LoopType, fir::IterWhileOp>) {
    // The innermost loop yields the generator's result.
    auto res = elementalGenerator(elem);
    builder.create<fir::ResultOp>(loc, res);
    builder.setInsertionPointAfter(outerLoop);
    return outerLoop.getResult(0);
  } else {
    elementalGenerator(elem);
    builder.setInsertionPointAfter(outerLoop);
    return {};
  }
}
305 | |
306 | void Fortran::lower::VectorSubscriptBox::loopOverElements( |
307 | fir::FirOpBuilder &builder, mlir::Location loc, |
308 | const ElementalGenerator &elementalGenerator) { |
309 | mlir::Value initialCondition; |
310 | loopOverElementsBase<fir::DoLoopOp, ElementalGenerator>( |
311 | builder, loc, elementalGenerator, initialCondition); |
312 | } |
313 | |
314 | mlir::Value Fortran::lower::VectorSubscriptBox::loopOverElementsWhile( |
315 | fir::FirOpBuilder &builder, mlir::Location loc, |
316 | const ElementalGeneratorWithBoolReturn &elementalGenerator, |
317 | mlir::Value initialCondition) { |
318 | return loopOverElementsBase<fir::IterWhileOp, |
319 | ElementalGeneratorWithBoolReturn>( |
320 | builder, loc, elementalGenerator, initialCondition); |
321 | } |
322 | |
323 | mlir::Value |
324 | Fortran::lower::VectorSubscriptBox::createSlice(fir::FirOpBuilder &builder, |
325 | mlir::Location loc) { |
326 | mlir::Type idxTy = builder.getIndexType(); |
327 | llvm::SmallVector<mlir::Value> triples; |
328 | mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); |
329 | auto undef = builder.create<fir::UndefOp>(loc, idxTy); |
330 | for (const LoweredSubscript &subscript : loweredSubscripts) |
331 | Fortran::common::visit(Fortran::common::visitors{ |
332 | [&](const LoweredTriplet &triplet) { |
333 | triples.emplace_back(triplet.lb); |
334 | triples.emplace_back(triplet.ub); |
335 | triples.emplace_back(triplet.stride); |
336 | }, |
337 | [&](const LoweredVectorSubscript &vector) { |
338 | triples.emplace_back(one); |
339 | triples.emplace_back(vector.size); |
340 | triples.emplace_back(one); |
341 | }, |
342 | [&](const mlir::Value &i) { |
343 | triples.emplace_back(i); |
344 | triples.emplace_back(undef); |
345 | triples.emplace_back(undef); |
346 | }, |
347 | }, |
348 | subscript); |
349 | return builder.create<fir::SliceOp>(loc, triples, componentPath); |
350 | } |
351 | |
352 | llvm::SmallVector<std::tuple<mlir::Value, mlir::Value, mlir::Value>> |
353 | Fortran::lower::VectorSubscriptBox::genLoopBounds(fir::FirOpBuilder &builder, |
354 | mlir::Location loc) { |
355 | mlir::Type idxTy = builder.getIndexType(); |
356 | mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); |
357 | mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0); |
358 | llvm::SmallVector<std::tuple<mlir::Value, mlir::Value, mlir::Value>> bounds; |
359 | size_t dimension = loweredSubscripts.size(); |
360 | for (const LoweredSubscript &subscript : llvm::reverse(loweredSubscripts)) { |
361 | --dimension; |
362 | if (std::holds_alternative<mlir::Value>(subscript)) |
363 | continue; |
364 | mlir::Value lb, ub, step; |
365 | if (const auto *triplet = std::get_if<LoweredTriplet>(&subscript)) { |
366 | mlir::Value extent = builder.genExtentFromTriplet( |
367 | loc, triplet->lb, triplet->ub, triplet->stride, idxTy); |
368 | mlir::Value baseLb = fir::factory::readLowerBound( |
369 | builder, loc, loweredBase, dimension, one); |
370 | baseLb = builder.createConvert(loc, idxTy, baseLb); |
371 | lb = baseLb; |
372 | ub = builder.create<mlir::arith::SubIOp>(loc, idxTy, extent, one); |
373 | ub = builder.create<mlir::arith::AddIOp>(loc, idxTy, ub, baseLb); |
374 | step = one; |
375 | } else { |
376 | const auto &vector = std::get<LoweredVectorSubscript>(subscript); |
377 | lb = zero; |
378 | ub = builder.create<mlir::arith::SubIOp>(loc, idxTy, vector.size, one); |
379 | step = one; |
380 | } |
381 | bounds.emplace_back(lb, ub, step); |
382 | } |
383 | return bounds; |
384 | } |
385 | |
386 | fir::ExtendedValue Fortran::lower::VectorSubscriptBox::getElementAt( |
387 | fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value shape, |
388 | mlir::Value slice, mlir::ValueRange inductionVariables) { |
389 | /// Generate the indexes for the array_coor inside the loops. |
390 | mlir::Type idxTy = builder.getIndexType(); |
391 | llvm::SmallVector<mlir::Value> indexes; |
392 | size_t inductionIdx = inductionVariables.size() - 1; |
393 | for (const LoweredSubscript &subscript : loweredSubscripts) |
394 | Fortran::common::visit( |
395 | Fortran::common::visitors{ |
396 | [&](const LoweredTriplet &triplet) { |
397 | indexes.emplace_back(inductionVariables[inductionIdx--]); |
398 | }, |
399 | [&](const LoweredVectorSubscript &vector) { |
400 | mlir::Value vecIndex = inductionVariables[inductionIdx--]; |
401 | mlir::Value vecBase = fir::getBase(vector.vector); |
402 | mlir::Type vecEleTy = fir::unwrapSequenceType( |
403 | fir::unwrapPassByRefType(vecBase.getType())); |
404 | mlir::Type refTy = builder.getRefType(vecEleTy); |
405 | auto vecEltRef = builder.create<fir::CoordinateOp>( |
406 | loc, refTy, vecBase, vecIndex); |
407 | auto vecElt = |
408 | builder.create<fir::LoadOp>(loc, vecEleTy, vecEltRef); |
409 | indexes.emplace_back(builder.createConvert(loc, idxTy, vecElt)); |
410 | }, |
411 | [&](const mlir::Value &i) { |
412 | indexes.emplace_back(builder.createConvert(loc, idxTy, i)); |
413 | }, |
414 | }, |
415 | subscript); |
416 | mlir::Type refTy = builder.getRefType(getElementType()); |
417 | auto elementAddr = builder.create<fir::ArrayCoorOp>( |
418 | loc, refTy, fir::getBase(loweredBase), shape, slice, indexes, |
419 | fir::getTypeParams(loweredBase)); |
420 | fir::ExtendedValue element = fir::factory::arraySectionElementToExtendedValue( |
421 | builder, loc, loweredBase, elementAddr, slice); |
422 | if (!substringBounds.empty()) { |
423 | const fir::CharBoxValue *charBox = element.getCharBox(); |
424 | assert(charBox && "substring requires CharBox base" ); |
425 | fir::factory::CharacterExprHelper helper{builder, loc}; |
426 | return helper.createSubstring(*charBox, substringBounds); |
427 | } |
428 | return element; |
429 | } |
430 | |