1//===-- OneToNTypeConversion.cpp - Utils for 1:N type conversion-*- C++ -*-===//
2//
3// Licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
#include "mlir/Transforms/OneToNTypeConversion.h"

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/ErrorHandling.h"

#include <unordered_map>
15
16using namespace llvm;
17using namespace mlir;
18
19std::optional<SmallVector<Value>>
20OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder,
21 Location loc,
22 TypeRange resultTypes,
23 Value input) const {
24 for (const OneToNMaterializationCallbackFn &fn :
25 llvm::reverse(C: oneToNTargetMaterializations)) {
26 if (std::optional<SmallVector<Value>> result =
27 fn(builder, resultTypes, input, loc))
28 return *result;
29 }
30 return std::nullopt;
31}
32
33TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const {
34 TypeRange convertedTypes = getConvertedTypes();
35 if (auto mapping = getInputMapping(input: originalTypeNo))
36 return convertedTypes.slice(n: mapping->inputNo, m: mapping->size);
37 return {};
38}
39
40ValueRange
41OneToNTypeMapping::getConvertedValues(ValueRange convertedValues,
42 unsigned originalValueNo) const {
43 if (auto mapping = getInputMapping(input: originalValueNo))
44 return convertedValues.slice(n: mapping->inputNo, m: mapping->size);
45 return {};
46}
47
48void OneToNTypeMapping::convertLocation(
49 Value originalValue, unsigned originalValueNo,
50 llvm::SmallVectorImpl<Location> &result) const {
51 if (auto mapping = getInputMapping(input: originalValueNo))
52 result.append(NumInputs: mapping->size, Elt: originalValue.getLoc());
53}
54
55void OneToNTypeMapping::convertLocations(
56 ValueRange originalValues, llvm::SmallVectorImpl<Location> &result) const {
57 assert(originalValues.size() == getOriginalTypes().size());
58 for (auto [i, value] : llvm::enumerate(First&: originalValues))
59 convertLocation(originalValue: value, originalValueNo: i, result);
60}
61
62static bool isIdentityConversion(Type originalType, TypeRange convertedTypes) {
63 return convertedTypes.size() == 1 && convertedTypes[0] == originalType;
64}
65
66bool OneToNTypeMapping::hasNonIdentityConversion() const {
67 // XXX: I think that the original types and the converted types are the same
68 // iff there was no non-identity type conversion. If that is true, the
69 // patterns could actually test whether there is anything useful to do
70 // without having access to the signature conversion.
71 for (auto [i, originalType] : llvm::enumerate(First: originalTypes)) {
72 TypeRange types = getConvertedTypes(originalTypeNo: i);
73 if (!isIdentityConversion(originalType, convertedTypes: types)) {
74 assert(TypeRange(originalTypes) != getConvertedTypes());
75 return true;
76 }
77 }
78 assert(TypeRange(originalTypes) == getConvertedTypes());
79 return false;
80}
81
namespace {
/// Kind of an unrealized cast inserted by the 1:N conversion driver. If such a
/// cast is still present after pattern application, its kind determines which
/// materialization replaces it.
enum class CastKind {
  /// Casts block arguments of the target types back to the source type; may
  /// become an argument materialization.
  Argument,
  /// Casts other values of the target types back to the source type; may
  /// become a source materialization.
  Source,
  /// Casts values of the source type to the target types; may become a target
  /// materialization.
  Target,
};
} // namespace
97
98/// Mapping of enum values to string values.
99StringRef getCastKindName(CastKind kind) {
100 static const std::unordered_map<CastKind, StringRef> castKindNames = {
101 {CastKind::Argument, "argument"},
102 {CastKind::Source, "source"},
103 {CastKind::Target, "target"}};
104 return castKindNames.at(k: kind);
105}
106
/// Name of the string attribute attached to every unrealized cast inserted by
/// this driver; its value names the kind of the cast ("argument", "source",
/// or "target").
static constexpr const char *castKindAttrName =
    "__one-to-n-type-conversion_cast-kind__";
111
112/// Builds an `UnrealizedConversionCastOp` from the given inputs to the given
113/// result types. Returns the result values of the cast.
114static ValueRange buildUnrealizedCast(OpBuilder &builder, TypeRange resultTypes,
115 ValueRange inputs, CastKind kind) {
116 // Special case: 1-to-N conversion with N = 0. No need to build an
117 // UnrealizedConversionCastOp because the op will always be dead.
118 if (resultTypes.empty())
119 return ValueRange();
120
121 // Create cast.
122 Location loc = builder.getUnknownLoc();
123 if (!inputs.empty())
124 loc = inputs.front().getLoc();
125 auto castOp =
126 builder.create<UnrealizedConversionCastOp>(loc, resultTypes, inputs);
127
128 // Store cast kind as attribute.
129 auto kindAttr = StringAttr::get(builder.getContext(), getCastKindName(kind));
130 castOp->setAttr(castKindAttrName, kindAttr);
131
132 return castOp->getResults();
133}
134
135/// Builds one `UnrealizedConversionCastOp` for each of the given original
136/// values using the respective target types given in the provided conversion
137/// mapping and returns the results of these casts. If the conversion mapping of
138/// a value maps a type to itself (i.e., is an identity conversion), then no
139/// cast is inserted and the original value is returned instead.
140/// Note that these unrealized casts are different from target materializations
141/// in that they are *always* inserted, even if they immediately fold away, such
142/// that patterns always see valid intermediate IR, whereas materializations are
143/// only used in the places where the unrealized casts *don't* fold away.
144static SmallVector<Value>
145buildUnrealizedForwardCasts(ValueRange originalValues,
146 OneToNTypeMapping &conversion,
147 RewriterBase &rewriter, CastKind kind) {
148
149 // Convert each operand one by one.
150 SmallVector<Value> convertedValues;
151 convertedValues.reserve(N: conversion.getConvertedTypes().size());
152 for (auto [idx, originalValue] : llvm::enumerate(First&: originalValues)) {
153 TypeRange convertedTypes = conversion.getConvertedTypes(originalTypeNo: idx);
154
155 // Identity conversion: keep operand as is.
156 if (isIdentityConversion(originalType: originalValue.getType(), convertedTypes)) {
157 convertedValues.push_back(Elt: originalValue);
158 continue;
159 }
160
161 // Non-identity conversion: materialize target types.
162 ValueRange castResult =
163 buildUnrealizedCast(builder&: rewriter, resultTypes: convertedTypes, inputs: originalValue, kind);
164 convertedValues.append(in_start: castResult.begin(), in_end: castResult.end());
165 }
166
167 return convertedValues;
168}
169
170/// Builds one `UnrealizedConversionCastOp` for each sequence of the given
171/// original values to one value of the type they originated from, i.e., a
172/// "reverse" conversion from N converted values back to one value of the
173/// original type, using the given (forward) type conversion. If a given value
174/// was mapped to a value of the same type (i.e., the conversion in the mapping
175/// is an identity conversion), then the "converted" value is returned without
176/// cast.
177/// Note that these unrealized casts are different from source materializations
178/// in that they are *always* inserted, even if they immediately fold away, such
179/// that patterns always see valid intermediate IR, whereas materializations are
180/// only used in the places where the unrealized casts *don't* fold away.
181static SmallVector<Value>
182buildUnrealizedBackwardsCasts(ValueRange convertedValues,
183 const OneToNTypeMapping &typeConversion,
184 RewriterBase &rewriter) {
185 assert(typeConversion.getConvertedTypes() == convertedValues.getTypes());
186
187 // Create unrealized cast op for each converted result of the op.
188 SmallVector<Value> recastValues;
189 TypeRange originalTypes = typeConversion.getOriginalTypes();
190 recastValues.reserve(N: originalTypes.size());
191 auto convertedValueIt = convertedValues.begin();
192 for (auto [idx, originalType] : llvm::enumerate(First&: originalTypes)) {
193 TypeRange convertedTypes = typeConversion.getConvertedTypes(originalTypeNo: idx);
194 size_t numConvertedValues = convertedTypes.size();
195 if (isIdentityConversion(originalType, convertedTypes)) {
196 // Identity conversion: take result as is.
197 recastValues.push_back(Elt: *convertedValueIt);
198 } else {
199 // Non-identity conversion: cast back to source type.
200 ValueRange recastValue = buildUnrealizedCast(
201 builder&: rewriter, resultTypes: originalType,
202 inputs: ValueRange{convertedValueIt, convertedValueIt + numConvertedValues},
203 kind: CastKind::Source);
204 assert(recastValue.size() == 1);
205 recastValues.push_back(Elt: recastValue.front());
206 }
207 convertedValueIt += numConvertedValues;
208 }
209
210 return recastValues;
211}
212
213void OneToNPatternRewriter::replaceOp(Operation *op, ValueRange newValues,
214 const OneToNTypeMapping &resultMapping) {
215 // Create a cast back to the original types and replace the results of the
216 // original op with those.
217 assert(newValues.size() == resultMapping.getConvertedTypes().size());
218 assert(op->getResultTypes() == resultMapping.getOriginalTypes());
219 PatternRewriter::InsertionGuard g(*this);
220 setInsertionPointAfter(op);
221 SmallVector<Value> castResults =
222 buildUnrealizedBackwardsCasts(convertedValues: newValues, typeConversion: resultMapping, rewriter&: *this);
223 replaceOp(op, newValues: castResults);
224}
225
226Block *OneToNPatternRewriter::applySignatureConversion(
227 Block *block, OneToNTypeMapping &argumentConversion) {
228 PatternRewriter::InsertionGuard g(*this);
229
230 // Split the block at the beginning to get a new block to use for the
231 // updated signature.
232 SmallVector<Location> locs;
233 argumentConversion.convertLocations(originalValues: block->getArguments(), result&: locs);
234 Block *newBlock =
235 createBlock(insertBefore: block, argTypes: argumentConversion.getConvertedTypes(), locs);
236 replaceAllUsesWith(from: block, to: newBlock);
237
238 // Create necessary casts in new block.
239 SmallVector<Value> castResults;
240 for (auto [i, arg] : llvm::enumerate(First: block->getArguments())) {
241 TypeRange convertedTypes = argumentConversion.getConvertedTypes(originalTypeNo: i);
242 ValueRange newArgs =
243 argumentConversion.getConvertedValues(convertedValues: newBlock->getArguments(), originalValueNo: i);
244 if (isIdentityConversion(originalType: arg.getType(), convertedTypes)) {
245 // Identity conversion: take argument as is.
246 assert(newArgs.size() == 1);
247 castResults.push_back(Elt: newArgs.front());
248 } else {
249 // Non-identity conversion: cast the converted arguments to the original
250 // type.
251 PatternRewriter::InsertionGuard g(*this);
252 setInsertionPointToStart(newBlock);
253 ValueRange castResult = buildUnrealizedCast(builder&: *this, resultTypes: arg.getType(), inputs: newArgs,
254 kind: CastKind::Argument);
255 assert(castResult.size() == 1);
256 castResults.push_back(Elt: castResult.front());
257 }
258 }
259
260 // Merge old block into new block such that we only have the latter with the
261 // new signature.
262 mergeBlocks(source: block, dest: newBlock, argValues: castResults);
263
264 return newBlock;
265}
266
267LogicalResult
268OneToNConversionPattern::matchAndRewrite(Operation *op,
269 PatternRewriter &rewriter) const {
270 auto *typeConverter = getTypeConverter<OneToNTypeConverter>();
271
272 // Construct conversion mapping for results.
273 Operation::result_type_range originalResultTypes = op->getResultTypes();
274 OneToNTypeMapping resultMapping(originalResultTypes);
275 if (failed(result: typeConverter->computeTypeMapping(types: originalResultTypes,
276 result&: resultMapping)))
277 return failure();
278
279 // Construct conversion mapping for operands.
280 Operation::operand_type_range originalOperandTypes = op->getOperandTypes();
281 OneToNTypeMapping operandMapping(originalOperandTypes);
282 if (failed(result: typeConverter->computeTypeMapping(types: originalOperandTypes,
283 result&: operandMapping)))
284 return failure();
285
286 // Cast operands to target types.
287 SmallVector<Value> convertedOperands = buildUnrealizedForwardCasts(
288 originalValues: op->getOperands(), conversion&: operandMapping, rewriter, kind: CastKind::Target);
289
290 // Create a `OneToNPatternRewriter` for the pattern, which provides additional
291 // functionality.
292 // TODO(ingomueller): I guess it would be better to use only one rewriter
293 // throughout the whole pass, but that would require to
294 // drive the pattern application ourselves, which is a lot
295 // of additional boilerplate code. This seems to work fine,
296 // so I leave it like this for the time being.
297 OneToNPatternRewriter oneToNPatternRewriter(rewriter.getContext(),
298 rewriter.getListener());
299 oneToNPatternRewriter.restoreInsertionPoint(ip: rewriter.saveInsertionPoint());
300
301 // Apply actual pattern.
302 if (failed(result: matchAndRewrite(op, rewriter&: oneToNPatternRewriter, operandMapping,
303 resultMapping, convertedOperands)))
304 return failure();
305
306 return success();
307}
308
309namespace mlir {
310
311// This function applies the provided patterns using
312// `applyPatternsAndFoldGreedily` and then replaces all newly inserted
313// `UnrealizedConversionCastOps` that haven't folded away. ("Backward" casts
314// from target to source types inserted by a `OneToNConversionPattern` normally
315// fold away with the "forward" casts from source to target types inserted by
316// the next pattern.) To understand which casts are "newly inserted", all casts
317// inserted by this pass are annotated with a string attribute that also
318// documents which kind of the cast (source, argument, or target).
319LogicalResult
320applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter,
321 const FrozenRewritePatternSet &patterns) {
322#ifndef NDEBUG
323 // Remember existing unrealized casts. This data structure is only used in
324 // asserts; building it only for that purpose may be an overkill.
325 SmallSet<UnrealizedConversionCastOp, 4> existingCasts;
326 op->walk(callback: [&](UnrealizedConversionCastOp castOp) {
327 assert(!castOp->hasAttr(castKindAttrName));
328 existingCasts.insert(castOp);
329 });
330#endif // NDEBUG
331
332 // Apply provided conversion patterns.
333 if (failed(result: applyPatternsAndFoldGreedily(op, patterns))) {
334 emitError(loc: op->getLoc()) << "failed to apply conversion patterns";
335 return failure();
336 }
337
338 // Find all unrealized casts inserted by the pass that haven't folded away.
339 SmallVector<UnrealizedConversionCastOp> worklist;
340 op->walk(callback: [&](UnrealizedConversionCastOp castOp) {
341 if (castOp->hasAttr(castKindAttrName)) {
342 assert(!existingCasts.contains(castOp));
343 worklist.push_back(castOp);
344 }
345 });
346
347 // Replace new casts with user materializations.
348 IRRewriter rewriter(op->getContext());
349 for (UnrealizedConversionCastOp castOp : worklist) {
350 TypeRange resultTypes = castOp->getResultTypes();
351 ValueRange operands = castOp->getOperands();
352 StringRef castKind =
353 castOp->getAttrOfType<StringAttr>(castKindAttrName).getValue();
354 rewriter.setInsertionPoint(castOp);
355
356#ifndef NDEBUG
357 // Determine whether operands or results are already legal to test some
358 // assumptions for the different kind of materializations. These properties
359 // are only used it asserts and it may be overkill to compute them.
360 bool areOperandTypesLegal = llvm::all_of(
361 operands.getTypes(), [&](Type t) { return typeConverter.isLegal(t); });
362 bool areResultsTypesLegal = llvm::all_of(
363 resultTypes, [&](Type t) { return typeConverter.isLegal(t); });
364#endif // NDEBUG
365
366 // Add materialization and remember materialized results.
367 SmallVector<Value> materializedResults;
368 if (castKind == getCastKindName(CastKind::Target)) {
369 // Target materialization.
370 assert(!areOperandTypesLegal && areResultsTypesLegal &&
371 operands.size() == 1 && "found unexpected target cast");
372 std::optional<SmallVector<Value>> maybeResults =
373 typeConverter.materializeTargetConversion(
374 rewriter, castOp->getLoc(), resultTypes, operands.front());
375 if (!maybeResults) {
376 emitError(castOp->getLoc())
377 << "failed to create target materialization";
378 return failure();
379 }
380 materializedResults = maybeResults.value();
381 } else {
382 // Source and argument materializations.
383 assert(areOperandTypesLegal && !areResultsTypesLegal &&
384 resultTypes.size() == 1 && "found unexpected cast");
385 std::optional<Value> maybeResult;
386 if (castKind == getCastKindName(CastKind::Source)) {
387 // Source materialization.
388 maybeResult = typeConverter.materializeSourceConversion(
389 rewriter, castOp->getLoc(), resultTypes.front(),
390 castOp.getOperands());
391 } else {
392 // Argument materialization.
393 assert(castKind == getCastKindName(CastKind::Argument) &&
394 "unexpected value of cast kind attribute");
395 assert(llvm::all_of(operands, llvm::IsaPred<BlockArgument>));
396 maybeResult = typeConverter.materializeArgumentConversion(
397 rewriter, castOp->getLoc(), resultTypes.front(),
398 castOp.getOperands());
399 }
400 if (!maybeResult.has_value() || !maybeResult.value()) {
401 emitError(castOp->getLoc())
402 << "failed to create " << castKind << " materialization";
403 return failure();
404 }
405 materializedResults = {maybeResult.value()};
406 }
407
408 // Replace the cast with the result of the materialization.
409 rewriter.replaceOp(castOp, materializedResults);
410 }
411
412 return success();
413}
414
415} // namespace mlir
416

source code of mlir/lib/Transforms/Utils/OneToNTypeConversion.cpp