1 | //===-- VectorSubscripts.cpp -- Vector subscripts tools -------------------===// |
2 | // |
3 | // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | // |
9 | // Coding style: https://mlir.llvm.org/getting_started/DeveloperGuide/ |
10 | // |
11 | //===----------------------------------------------------------------------===// |
12 | |
13 | #include "flang/Lower/VectorSubscripts.h" |
14 | #include "flang/Lower/AbstractConverter.h" |
15 | #include "flang/Lower/Support/Utils.h" |
16 | #include "flang/Optimizer/Builder/Character.h" |
17 | #include "flang/Optimizer/Builder/Complex.h" |
18 | #include "flang/Optimizer/Builder/FIRBuilder.h" |
19 | #include "flang/Optimizer/Builder/Todo.h" |
20 | #include "flang/Semantics/expression.h" |
21 | |
22 | namespace { |
23 | /// Helper class to lower a designator containing vector subscripts into a |
24 | /// lowered representation that can be worked with. |
25 | class VectorSubscriptBoxBuilder { |
26 | public: |
27 | VectorSubscriptBoxBuilder(mlir::Location loc, |
28 | Fortran::lower::AbstractConverter &converter, |
29 | Fortran::lower::StatementContext &stmtCtx) |
30 | : converter{converter}, stmtCtx{stmtCtx}, loc{loc} {} |
31 | |
32 | Fortran::lower::VectorSubscriptBox gen(const Fortran::lower::SomeExpr &expr) { |
33 | elementType = genDesignator(expr); |
34 | return Fortran::lower::VectorSubscriptBox( |
35 | std::move(loweredBase), std::move(loweredSubscripts), |
36 | std::move(componentPath), substringBounds, elementType); |
37 | } |
38 | |
39 | private: |
40 | using LoweredVectorSubscript = |
41 | Fortran::lower::VectorSubscriptBox::LoweredVectorSubscript; |
42 | using LoweredTriplet = Fortran::lower::VectorSubscriptBox::LoweredTriplet; |
43 | using LoweredSubscript = Fortran::lower::VectorSubscriptBox::LoweredSubscript; |
44 | using MaybeSubstring = Fortran::lower::VectorSubscriptBox::MaybeSubstring; |
45 | |
46 | /// genDesignator unwraps a Designator<T> and calls `gen` on what the |
47 | /// designator actually contains. |
48 | template <typename A> |
49 | mlir::Type genDesignator(const A &) { |
50 | fir::emitFatalError(loc, "expr must contain a designator" ); |
51 | } |
52 | template <typename T> |
53 | mlir::Type genDesignator(const Fortran::evaluate::Expr<T> &expr) { |
54 | using ExprVariant = decltype(Fortran::evaluate::Expr<T>::u); |
55 | using Designator = Fortran::evaluate::Designator<T>; |
56 | if constexpr (Fortran::common::HasMember<Designator, ExprVariant>) { |
57 | const auto &designator = std::get<Designator>(expr.u); |
58 | return std::visit([&](const auto &x) { return gen(x); }, designator.u); |
59 | } else { |
60 | return std::visit([&](const auto &x) { return genDesignator(x); }, |
61 | expr.u); |
62 | } |
63 | } |
64 | |
65 | // The gen(X) methods visit X to lower its base and subscripts and return the |
66 | // type of X elements. |
67 | |
68 | mlir::Type gen(const Fortran::evaluate::DataRef &dataRef) { |
69 | return std::visit([&](const auto &ref) -> mlir::Type { return gen(ref); }, |
70 | dataRef.u); |
71 | } |
72 | |
73 | mlir::Type gen(const Fortran::evaluate::SymbolRef &symRef) { |
74 | // Never visited because expr lowering is used to lowered the ranked |
75 | // ArrayRef. |
76 | fir::emitFatalError( |
77 | loc, "expected at least one ArrayRef with vector susbcripts" ); |
78 | } |
79 | |
80 | mlir::Type gen(const Fortran::evaluate::Substring &substring) { |
81 | // StaticDataObject::Pointer bases are constants and cannot be |
82 | // subscripted, so the base must be a DataRef here. |
83 | mlir::Type baseElementType = |
84 | gen(std::get<Fortran::evaluate::DataRef>(substring.parent())); |
85 | fir::FirOpBuilder &builder = converter.getFirOpBuilder(); |
86 | mlir::Type idxTy = builder.getIndexType(); |
87 | mlir::Value lb = genScalarValue(substring.lower()); |
88 | substringBounds.emplace_back(builder.createConvert(loc, idxTy, lb)); |
89 | if (const auto &ubExpr = substring.upper()) { |
90 | mlir::Value ub = genScalarValue(*ubExpr); |
91 | substringBounds.emplace_back(builder.createConvert(loc, idxTy, ub)); |
92 | } |
93 | return baseElementType; |
94 | } |
95 | |
96 | mlir::Type gen(const Fortran::evaluate::ComplexPart &complexPart) { |
97 | auto complexType = gen(complexPart.complex()); |
98 | fir::FirOpBuilder &builder = converter.getFirOpBuilder(); |
99 | mlir::Type i32Ty = builder.getI32Type(); // llvm's GEP requires i32 |
100 | mlir::Value offset = builder.createIntegerConstant( |
101 | loc, i32Ty, |
102 | complexPart.part() == Fortran::evaluate::ComplexPart::Part::RE ? 0 : 1); |
103 | componentPath.emplace_back(offset); |
104 | return fir::factory::Complex{builder, loc}.getComplexPartType(complexType); |
105 | } |
106 | |
107 | mlir::Type gen(const Fortran::evaluate::Component &component) { |
108 | auto recTy = gen(component.base()).cast<fir::RecordType>(); |
109 | const Fortran::semantics::Symbol &componentSymbol = |
110 | component.GetLastSymbol(); |
111 | // Parent components will not be found here, they are not part |
112 | // of the FIR type and cannot be used in the path yet. |
113 | if (componentSymbol.test(Fortran::semantics::Symbol::Flag::ParentComp)) |
114 | TODO(loc, "reference to parent component" ); |
115 | mlir::Type fldTy = fir::FieldType::get(&converter.getMLIRContext()); |
116 | llvm::StringRef componentName = toStringRef(componentSymbol.name()); |
117 | // Parameters threading in field_index is not yet very clear. We only |
118 | // have the ones of the ranked array ref at hand, but it looks like |
119 | // the fir.field_index expects the one of the direct base. |
120 | if (recTy.getNumLenParams() != 0) |
121 | TODO(loc, "threading length parameters in field index op" ); |
122 | fir::FirOpBuilder &builder = converter.getFirOpBuilder(); |
123 | componentPath.emplace_back(builder.create<fir::FieldIndexOp>( |
124 | loc, fldTy, componentName, recTy, /*typeParams*/ std::nullopt)); |
125 | return fir::unwrapSequenceType(recTy.getType(componentName)); |
126 | } |
127 | |
128 | mlir::Type gen(const Fortran::evaluate::ArrayRef &arrayRef) { |
129 | auto isTripletOrVector = |
130 | [](const Fortran::evaluate::Subscript &subscript) -> bool { |
131 | return std::visit( |
132 | Fortran::common::visitors{ |
133 | [](const Fortran::evaluate::IndirectSubscriptIntegerExpr &expr) { |
134 | return expr.value().Rank() != 0; |
135 | }, |
136 | [&](const Fortran::evaluate::Triplet &) { return true; }}, |
137 | subscript.u); |
138 | }; |
139 | if (llvm::any_of(arrayRef.subscript(), isTripletOrVector)) |
140 | return genRankedArrayRefSubscriptAndBase(arrayRef); |
141 | |
142 | // This is a scalar ArrayRef (only scalar indexes), collect the indexes and |
143 | // visit the base that must contain another arrayRef with the vector |
144 | // subscript. |
145 | mlir::Type elementType = gen(namedEntityToDataRef(arrayRef.base())); |
146 | for (const Fortran::evaluate::Subscript &subscript : arrayRef.subscript()) { |
147 | const auto &expr = |
148 | std::get<Fortran::evaluate::IndirectSubscriptIntegerExpr>( |
149 | subscript.u); |
150 | componentPath.emplace_back(genScalarValue(expr.value())); |
151 | } |
152 | return elementType; |
153 | } |
154 | |
155 | /// Lower the subscripts and base of the ArrayRef that is an array (there must |
156 | /// be one since there is a vector subscript, and there can only be one |
157 | /// according to C925). |
158 | mlir::Type genRankedArrayRefSubscriptAndBase( |
159 | const Fortran::evaluate::ArrayRef &arrayRef) { |
160 | // Lower the save the base |
161 | Fortran::lower::SomeExpr baseExpr = namedEntityToExpr(arrayRef.base()); |
162 | loweredBase = converter.genExprAddr(baseExpr, stmtCtx); |
163 | // Lower and save the subscripts |
164 | fir::FirOpBuilder &builder = converter.getFirOpBuilder(); |
165 | mlir::Type idxTy = builder.getIndexType(); |
166 | mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); |
167 | for (const auto &subscript : llvm::enumerate(arrayRef.subscript())) { |
168 | std::visit( |
169 | Fortran::common::visitors{ |
170 | [&](const Fortran::evaluate::IndirectSubscriptIntegerExpr &expr) { |
171 | if (expr.value().Rank() == 0) { |
172 | // Simple scalar subscript |
173 | loweredSubscripts.emplace_back(genScalarValue(expr.value())); |
174 | } else { |
175 | // Vector subscript. |
176 | // Remove conversion if any to avoid temp creation that may |
177 | // have been added by the front-end to avoid the creation of a |
178 | // temp array value. |
179 | auto vector = converter.genExprAddr( |
180 | ignoreEvConvert(expr.value()), stmtCtx); |
181 | mlir::Value size = |
182 | fir::factory::readExtent(builder, loc, vector, /*dim=*/0); |
183 | size = builder.createConvert(loc, idxTy, size); |
184 | loweredSubscripts.emplace_back( |
185 | LoweredVectorSubscript{std::move(vector), size}); |
186 | } |
187 | }, |
188 | [&](const Fortran::evaluate::Triplet &triplet) { |
189 | mlir::Value lb, ub; |
190 | if (const auto &lbExpr = triplet.lower()) |
191 | lb = genScalarValue(*lbExpr); |
192 | else |
193 | lb = fir::factory::readLowerBound(builder, loc, loweredBase, |
194 | subscript.index(), one); |
195 | if (const auto &ubExpr = triplet.upper()) |
196 | ub = genScalarValue(*ubExpr); |
197 | else |
198 | ub = fir::factory::readExtent(builder, loc, loweredBase, |
199 | subscript.index()); |
200 | lb = builder.createConvert(loc, idxTy, lb); |
201 | ub = builder.createConvert(loc, idxTy, ub); |
202 | mlir::Value stride = genScalarValue(triplet.stride()); |
203 | stride = builder.createConvert(loc, idxTy, stride); |
204 | loweredSubscripts.emplace_back(LoweredTriplet{lb, ub, stride}); |
205 | }, |
206 | }, |
207 | subscript.value().u); |
208 | } |
209 | return fir::unwrapSequenceType( |
210 | fir::unwrapPassByRefType(fir::getBase(loweredBase).getType())); |
211 | } |
212 | |
213 | mlir::Type gen(const Fortran::evaluate::CoarrayRef &) { |
214 | // Is this possible/legal ? |
215 | TODO(loc, "coarray: reference to coarray object with vector subscript in " |
216 | "IO input" ); |
217 | } |
218 | |
219 | template <typename A> |
220 | mlir::Value genScalarValue(const A &expr) { |
221 | return fir::getBase(converter.genExprValue(toEvExpr(expr), stmtCtx)); |
222 | } |
223 | |
224 | Fortran::evaluate::DataRef |
225 | namedEntityToDataRef(const Fortran::evaluate::NamedEntity &namedEntity) { |
226 | if (namedEntity.IsSymbol()) |
227 | return Fortran::evaluate::DataRef{namedEntity.GetFirstSymbol()}; |
228 | return Fortran::evaluate::DataRef{namedEntity.GetComponent()}; |
229 | } |
230 | |
231 | Fortran::lower::SomeExpr |
232 | namedEntityToExpr(const Fortran::evaluate::NamedEntity &namedEntity) { |
233 | return Fortran::evaluate::AsGenericExpr(namedEntityToDataRef(namedEntity)) |
234 | .value(); |
235 | } |
236 | |
237 | Fortran::lower::AbstractConverter &converter; |
238 | Fortran::lower::StatementContext &stmtCtx; |
239 | mlir::Location loc; |
240 | /// Elements of VectorSubscriptBox being built. |
241 | fir::ExtendedValue loweredBase; |
242 | llvm::SmallVector<LoweredSubscript, 16> loweredSubscripts; |
243 | llvm::SmallVector<mlir::Value> componentPath; |
244 | MaybeSubstring substringBounds; |
245 | mlir::Type elementType; |
246 | }; |
247 | } // namespace |
248 | |
249 | Fortran::lower::VectorSubscriptBox Fortran::lower::genVectorSubscriptBox( |
250 | mlir::Location loc, Fortran::lower::AbstractConverter &converter, |
251 | Fortran::lower::StatementContext &stmtCtx, |
252 | const Fortran::lower::SomeExpr &expr) { |
253 | return VectorSubscriptBoxBuilder(loc, converter, stmtCtx).gen(expr); |
254 | } |
255 | |
256 | template <typename LoopType, typename Generator> |
257 | mlir::Value Fortran::lower::VectorSubscriptBox::loopOverElementsBase( |
258 | fir::FirOpBuilder &builder, mlir::Location loc, |
259 | const Generator &elementalGenerator, |
260 | [[maybe_unused]] mlir::Value initialCondition) { |
261 | mlir::Value shape = builder.createShape(loc, loweredBase); |
262 | mlir::Value slice = createSlice(builder, loc); |
263 | |
264 | // Create loop nest for triplets and vector subscripts in column |
265 | // major order. |
266 | llvm::SmallVector<mlir::Value> inductionVariables; |
267 | LoopType outerLoop; |
268 | for (auto [lb, ub, step] : genLoopBounds(builder, loc)) { |
269 | LoopType loop; |
270 | if constexpr (std::is_same_v<LoopType, fir::IterWhileOp>) { |
271 | loop = |
272 | builder.create<fir::IterWhileOp>(loc, lb, ub, step, initialCondition); |
273 | initialCondition = loop.getIterateVar(); |
274 | if (!outerLoop) |
275 | outerLoop = loop; |
276 | else |
277 | builder.create<fir::ResultOp>(loc, loop.getResult(0)); |
278 | } else { |
279 | loop = |
280 | builder.create<fir::DoLoopOp>(loc, lb, ub, step, /*unordered=*/false); |
281 | if (!outerLoop) |
282 | outerLoop = loop; |
283 | } |
284 | builder.setInsertionPointToStart(loop.getBody()); |
285 | inductionVariables.push_back(loop.getInductionVar()); |
286 | } |
287 | assert(outerLoop && !inductionVariables.empty() && |
288 | "at least one loop should be created" ); |
289 | |
290 | fir::ExtendedValue elem = |
291 | getElementAt(builder, loc, shape, slice, inductionVariables); |
292 | |
293 | if constexpr (std::is_same_v<LoopType, fir::IterWhileOp>) { |
294 | auto res = elementalGenerator(elem); |
295 | builder.create<fir::ResultOp>(loc, res); |
296 | builder.setInsertionPointAfter(outerLoop); |
297 | return outerLoop.getResult(0); |
298 | } else { |
299 | elementalGenerator(elem); |
300 | builder.setInsertionPointAfter(outerLoop); |
301 | return {}; |
302 | } |
303 | } |
304 | |
305 | void Fortran::lower::VectorSubscriptBox::loopOverElements( |
306 | fir::FirOpBuilder &builder, mlir::Location loc, |
307 | const ElementalGenerator &elementalGenerator) { |
308 | mlir::Value initialCondition; |
309 | loopOverElementsBase<fir::DoLoopOp, ElementalGenerator>( |
310 | builder, loc, elementalGenerator, initialCondition); |
311 | } |
312 | |
313 | mlir::Value Fortran::lower::VectorSubscriptBox::loopOverElementsWhile( |
314 | fir::FirOpBuilder &builder, mlir::Location loc, |
315 | const ElementalGeneratorWithBoolReturn &elementalGenerator, |
316 | mlir::Value initialCondition) { |
317 | return loopOverElementsBase<fir::IterWhileOp, |
318 | ElementalGeneratorWithBoolReturn>( |
319 | builder, loc, elementalGenerator, initialCondition); |
320 | } |
321 | |
322 | mlir::Value |
323 | Fortran::lower::VectorSubscriptBox::createSlice(fir::FirOpBuilder &builder, |
324 | mlir::Location loc) { |
325 | mlir::Type idxTy = builder.getIndexType(); |
326 | llvm::SmallVector<mlir::Value> triples; |
327 | mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); |
328 | auto undef = builder.create<fir::UndefOp>(loc, idxTy); |
329 | for (const LoweredSubscript &subscript : loweredSubscripts) |
330 | std::visit(Fortran::common::visitors{ |
331 | [&](const LoweredTriplet &triplet) { |
332 | triples.emplace_back(triplet.lb); |
333 | triples.emplace_back(triplet.ub); |
334 | triples.emplace_back(triplet.stride); |
335 | }, |
336 | [&](const LoweredVectorSubscript &vector) { |
337 | triples.emplace_back(one); |
338 | triples.emplace_back(vector.size); |
339 | triples.emplace_back(one); |
340 | }, |
341 | [&](const mlir::Value &i) { |
342 | triples.emplace_back(i); |
343 | triples.emplace_back(undef); |
344 | triples.emplace_back(undef); |
345 | }, |
346 | }, |
347 | subscript); |
348 | return builder.create<fir::SliceOp>(loc, triples, componentPath); |
349 | } |
350 | |
351 | llvm::SmallVector<std::tuple<mlir::Value, mlir::Value, mlir::Value>> |
352 | Fortran::lower::VectorSubscriptBox::genLoopBounds(fir::FirOpBuilder &builder, |
353 | mlir::Location loc) { |
354 | mlir::Type idxTy = builder.getIndexType(); |
355 | mlir::Value one = builder.createIntegerConstant(loc, idxTy, 1); |
356 | mlir::Value zero = builder.createIntegerConstant(loc, idxTy, 0); |
357 | llvm::SmallVector<std::tuple<mlir::Value, mlir::Value, mlir::Value>> bounds; |
358 | size_t dimension = loweredSubscripts.size(); |
359 | for (const LoweredSubscript &subscript : llvm::reverse(loweredSubscripts)) { |
360 | --dimension; |
361 | if (std::holds_alternative<mlir::Value>(subscript)) |
362 | continue; |
363 | mlir::Value lb, ub, step; |
364 | if (const auto *triplet = std::get_if<LoweredTriplet>(&subscript)) { |
365 | mlir::Value extent = builder.genExtentFromTriplet( |
366 | loc, triplet->lb, triplet->ub, triplet->stride, idxTy); |
367 | mlir::Value baseLb = fir::factory::readLowerBound( |
368 | builder, loc, loweredBase, dimension, one); |
369 | baseLb = builder.createConvert(loc, idxTy, baseLb); |
370 | lb = baseLb; |
371 | ub = builder.create<mlir::arith::SubIOp>(loc, idxTy, extent, one); |
372 | ub = builder.create<mlir::arith::AddIOp>(loc, idxTy, ub, baseLb); |
373 | step = one; |
374 | } else { |
375 | const auto &vector = std::get<LoweredVectorSubscript>(subscript); |
376 | lb = zero; |
377 | ub = builder.create<mlir::arith::SubIOp>(loc, idxTy, vector.size, one); |
378 | step = one; |
379 | } |
380 | bounds.emplace_back(lb, ub, step); |
381 | } |
382 | return bounds; |
383 | } |
384 | |
385 | fir::ExtendedValue Fortran::lower::VectorSubscriptBox::getElementAt( |
386 | fir::FirOpBuilder &builder, mlir::Location loc, mlir::Value shape, |
387 | mlir::Value slice, mlir::ValueRange inductionVariables) { |
388 | /// Generate the indexes for the array_coor inside the loops. |
389 | mlir::Type idxTy = builder.getIndexType(); |
390 | llvm::SmallVector<mlir::Value> indexes; |
391 | size_t inductionIdx = inductionVariables.size() - 1; |
392 | for (const LoweredSubscript &subscript : loweredSubscripts) |
393 | std::visit(Fortran::common::visitors{ |
394 | [&](const LoweredTriplet &triplet) { |
395 | indexes.emplace_back(inductionVariables[inductionIdx--]); |
396 | }, |
397 | [&](const LoweredVectorSubscript &vector) { |
398 | mlir::Value vecIndex = inductionVariables[inductionIdx--]; |
399 | mlir::Value vecBase = fir::getBase(vector.vector); |
400 | mlir::Type vecEleTy = fir::unwrapSequenceType( |
401 | fir::unwrapPassByRefType(vecBase.getType())); |
402 | mlir::Type refTy = builder.getRefType(vecEleTy); |
403 | auto vecEltRef = builder.create<fir::CoordinateOp>( |
404 | loc, refTy, vecBase, vecIndex); |
405 | auto vecElt = |
406 | builder.create<fir::LoadOp>(loc, vecEleTy, vecEltRef); |
407 | indexes.emplace_back( |
408 | builder.createConvert(loc, idxTy, vecElt)); |
409 | }, |
410 | [&](const mlir::Value &i) { |
411 | indexes.emplace_back(builder.createConvert(loc, idxTy, i)); |
412 | }, |
413 | }, |
414 | subscript); |
415 | mlir::Type refTy = builder.getRefType(getElementType()); |
416 | auto elementAddr = builder.create<fir::ArrayCoorOp>( |
417 | loc, refTy, fir::getBase(loweredBase), shape, slice, indexes, |
418 | fir::getTypeParams(loweredBase)); |
419 | fir::ExtendedValue element = fir::factory::arraySectionElementToExtendedValue( |
420 | builder, loc, loweredBase, elementAddr, slice); |
421 | if (!substringBounds.empty()) { |
422 | const fir::CharBoxValue *charBox = element.getCharBox(); |
423 | assert(charBox && "substring requires CharBox base" ); |
424 | fir::factory::CharacterExprHelper helper{builder, loc}; |
425 | return helper.createSubstring(*charBox, substringBounds); |
426 | } |
427 | return element; |
428 | } |
429 | |