1 | //===-- OneToNTypeConversion.cpp - Utils for 1:N type conversion-*- C++ -*-===// |
2 | // |
3 | // Licensed under the Apache License v2.0 with LLVM Exceptions. |
4 | // See https://llvm.org/LICENSE.txt for license information. |
5 | // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
6 | // |
7 | //===----------------------------------------------------------------------===// |
8 | |
#include "mlir/Transforms/OneToNTypeConversion.h"

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/Support/ErrorHandling.h"

#include <unordered_map>
15 | |
16 | using namespace llvm; |
17 | using namespace mlir; |
18 | |
19 | std::optional<SmallVector<Value>> |
20 | OneToNTypeConverter::materializeTargetConversion(OpBuilder &builder, |
21 | Location loc, |
22 | TypeRange resultTypes, |
23 | Value input) const { |
24 | for (const OneToNMaterializationCallbackFn &fn : |
25 | llvm::reverse(C: oneToNTargetMaterializations)) { |
26 | if (std::optional<SmallVector<Value>> result = |
27 | fn(builder, resultTypes, input, loc)) |
28 | return *result; |
29 | } |
30 | return std::nullopt; |
31 | } |
32 | |
33 | TypeRange OneToNTypeMapping::getConvertedTypes(unsigned originalTypeNo) const { |
34 | TypeRange convertedTypes = getConvertedTypes(); |
35 | if (auto mapping = getInputMapping(input: originalTypeNo)) |
36 | return convertedTypes.slice(n: mapping->inputNo, m: mapping->size); |
37 | return {}; |
38 | } |
39 | |
40 | ValueRange |
41 | OneToNTypeMapping::getConvertedValues(ValueRange convertedValues, |
42 | unsigned originalValueNo) const { |
43 | if (auto mapping = getInputMapping(input: originalValueNo)) |
44 | return convertedValues.slice(n: mapping->inputNo, m: mapping->size); |
45 | return {}; |
46 | } |
47 | |
48 | void OneToNTypeMapping::convertLocation( |
49 | Value originalValue, unsigned originalValueNo, |
50 | llvm::SmallVectorImpl<Location> &result) const { |
51 | if (auto mapping = getInputMapping(input: originalValueNo)) |
52 | result.append(NumInputs: mapping->size, Elt: originalValue.getLoc()); |
53 | } |
54 | |
55 | void OneToNTypeMapping::convertLocations( |
56 | ValueRange originalValues, llvm::SmallVectorImpl<Location> &result) const { |
57 | assert(originalValues.size() == getOriginalTypes().size()); |
58 | for (auto [i, value] : llvm::enumerate(First&: originalValues)) |
59 | convertLocation(originalValue: value, originalValueNo: i, result); |
60 | } |
61 | |
62 | static bool isIdentityConversion(Type originalType, TypeRange convertedTypes) { |
63 | return convertedTypes.size() == 1 && convertedTypes[0] == originalType; |
64 | } |
65 | |
66 | bool OneToNTypeMapping::hasNonIdentityConversion() const { |
67 | // XXX: I think that the original types and the converted types are the same |
68 | // iff there was no non-identity type conversion. If that is true, the |
69 | // patterns could actually test whether there is anything useful to do |
70 | // without having access to the signature conversion. |
71 | for (auto [i, originalType] : llvm::enumerate(First: originalTypes)) { |
72 | TypeRange types = getConvertedTypes(originalTypeNo: i); |
73 | if (!isIdentityConversion(originalType, convertedTypes: types)) { |
74 | assert(TypeRange(originalTypes) != getConvertedTypes()); |
75 | return true; |
76 | } |
77 | } |
78 | assert(TypeRange(originalTypes) == getConvertedTypes()); |
79 | return false; |
80 | } |
81 | |
namespace {
/// The kinds of unrealized casts inserted by the 1:N type conversion. The kind
/// decides which materialization replaces a cast that does not fold away.
enum class CastKind {
  /// Casts block arguments of the target type back to the source type. If
  /// necessary, such a cast becomes an argument materialization.
  Argument,

  /// Casts other values of the target type back to the source type. If
  /// necessary, such a cast becomes a source materialization.
  Source,

  /// Casts values of the source type to the target type. If necessary, such a
  /// cast becomes a target materialization.
  Target,
};
} // namespace
97 | |
98 | /// Mapping of enum values to string values. |
99 | StringRef getCastKindName(CastKind kind) { |
100 | static const std::unordered_map<CastKind, StringRef> castKindNames = { |
101 | {CastKind::Argument, "argument" }, |
102 | {CastKind::Source, "source" }, |
103 | {CastKind::Target, "target" }}; |
104 | return castKindNames.at(k: kind); |
105 | } |
106 | |
/// Name of the string attribute attached to every unrealized cast inserted by
/// this file. Its value records the kind of the cast ("argument", "source", or
/// "target"), which later selects the materialization used for casts that do
/// not fold away.
static constexpr const char *castKindAttrName =
    "__one-to-n-type-conversion_cast-kind__";
111 | |
112 | /// Builds an `UnrealizedConversionCastOp` from the given inputs to the given |
113 | /// result types. Returns the result values of the cast. |
114 | static ValueRange buildUnrealizedCast(OpBuilder &builder, TypeRange resultTypes, |
115 | ValueRange inputs, CastKind kind) { |
116 | // Special case: 1-to-N conversion with N = 0. No need to build an |
117 | // UnrealizedConversionCastOp because the op will always be dead. |
118 | if (resultTypes.empty()) |
119 | return ValueRange(); |
120 | |
121 | // Create cast. |
122 | Location loc = builder.getUnknownLoc(); |
123 | if (!inputs.empty()) |
124 | loc = inputs.front().getLoc(); |
125 | auto castOp = |
126 | builder.create<UnrealizedConversionCastOp>(loc, resultTypes, inputs); |
127 | |
128 | // Store cast kind as attribute. |
129 | auto kindAttr = StringAttr::get(builder.getContext(), getCastKindName(kind)); |
130 | castOp->setAttr(castKindAttrName, kindAttr); |
131 | |
132 | return castOp->getResults(); |
133 | } |
134 | |
135 | /// Builds one `UnrealizedConversionCastOp` for each of the given original |
136 | /// values using the respective target types given in the provided conversion |
137 | /// mapping and returns the results of these casts. If the conversion mapping of |
138 | /// a value maps a type to itself (i.e., is an identity conversion), then no |
139 | /// cast is inserted and the original value is returned instead. |
140 | /// Note that these unrealized casts are different from target materializations |
141 | /// in that they are *always* inserted, even if they immediately fold away, such |
142 | /// that patterns always see valid intermediate IR, whereas materializations are |
143 | /// only used in the places where the unrealized casts *don't* fold away. |
144 | static SmallVector<Value> |
145 | buildUnrealizedForwardCasts(ValueRange originalValues, |
146 | OneToNTypeMapping &conversion, |
147 | RewriterBase &rewriter, CastKind kind) { |
148 | |
149 | // Convert each operand one by one. |
150 | SmallVector<Value> convertedValues; |
151 | convertedValues.reserve(N: conversion.getConvertedTypes().size()); |
152 | for (auto [idx, originalValue] : llvm::enumerate(First&: originalValues)) { |
153 | TypeRange convertedTypes = conversion.getConvertedTypes(originalTypeNo: idx); |
154 | |
155 | // Identity conversion: keep operand as is. |
156 | if (isIdentityConversion(originalType: originalValue.getType(), convertedTypes)) { |
157 | convertedValues.push_back(Elt: originalValue); |
158 | continue; |
159 | } |
160 | |
161 | // Non-identity conversion: materialize target types. |
162 | ValueRange castResult = |
163 | buildUnrealizedCast(builder&: rewriter, resultTypes: convertedTypes, inputs: originalValue, kind); |
164 | convertedValues.append(in_start: castResult.begin(), in_end: castResult.end()); |
165 | } |
166 | |
167 | return convertedValues; |
168 | } |
169 | |
170 | /// Builds one `UnrealizedConversionCastOp` for each sequence of the given |
171 | /// original values to one value of the type they originated from, i.e., a |
172 | /// "reverse" conversion from N converted values back to one value of the |
173 | /// original type, using the given (forward) type conversion. If a given value |
174 | /// was mapped to a value of the same type (i.e., the conversion in the mapping |
175 | /// is an identity conversion), then the "converted" value is returned without |
176 | /// cast. |
177 | /// Note that these unrealized casts are different from source materializations |
178 | /// in that they are *always* inserted, even if they immediately fold away, such |
179 | /// that patterns always see valid intermediate IR, whereas materializations are |
180 | /// only used in the places where the unrealized casts *don't* fold away. |
181 | static SmallVector<Value> |
182 | buildUnrealizedBackwardsCasts(ValueRange convertedValues, |
183 | const OneToNTypeMapping &typeConversion, |
184 | RewriterBase &rewriter) { |
185 | assert(typeConversion.getConvertedTypes() == convertedValues.getTypes()); |
186 | |
187 | // Create unrealized cast op for each converted result of the op. |
188 | SmallVector<Value> recastValues; |
189 | TypeRange originalTypes = typeConversion.getOriginalTypes(); |
190 | recastValues.reserve(N: originalTypes.size()); |
191 | auto convertedValueIt = convertedValues.begin(); |
192 | for (auto [idx, originalType] : llvm::enumerate(First&: originalTypes)) { |
193 | TypeRange convertedTypes = typeConversion.getConvertedTypes(originalTypeNo: idx); |
194 | size_t numConvertedValues = convertedTypes.size(); |
195 | if (isIdentityConversion(originalType, convertedTypes)) { |
196 | // Identity conversion: take result as is. |
197 | recastValues.push_back(Elt: *convertedValueIt); |
198 | } else { |
199 | // Non-identity conversion: cast back to source type. |
200 | ValueRange recastValue = buildUnrealizedCast( |
201 | builder&: rewriter, resultTypes: originalType, |
202 | inputs: ValueRange{convertedValueIt, convertedValueIt + numConvertedValues}, |
203 | kind: CastKind::Source); |
204 | assert(recastValue.size() == 1); |
205 | recastValues.push_back(Elt: recastValue.front()); |
206 | } |
207 | convertedValueIt += numConvertedValues; |
208 | } |
209 | |
210 | return recastValues; |
211 | } |
212 | |
213 | void OneToNPatternRewriter::replaceOp(Operation *op, ValueRange newValues, |
214 | const OneToNTypeMapping &resultMapping) { |
215 | // Create a cast back to the original types and replace the results of the |
216 | // original op with those. |
217 | assert(newValues.size() == resultMapping.getConvertedTypes().size()); |
218 | assert(op->getResultTypes() == resultMapping.getOriginalTypes()); |
219 | PatternRewriter::InsertionGuard g(*this); |
220 | setInsertionPointAfter(op); |
221 | SmallVector<Value> castResults = |
222 | buildUnrealizedBackwardsCasts(convertedValues: newValues, typeConversion: resultMapping, rewriter&: *this); |
223 | replaceOp(op, newValues: castResults); |
224 | } |
225 | |
226 | Block *OneToNPatternRewriter::applySignatureConversion( |
227 | Block *block, OneToNTypeMapping &argumentConversion) { |
228 | PatternRewriter::InsertionGuard g(*this); |
229 | |
230 | // Split the block at the beginning to get a new block to use for the |
231 | // updated signature. |
232 | SmallVector<Location> locs; |
233 | argumentConversion.convertLocations(originalValues: block->getArguments(), result&: locs); |
234 | Block *newBlock = |
235 | createBlock(insertBefore: block, argTypes: argumentConversion.getConvertedTypes(), locs); |
236 | replaceAllUsesWith(from: block, to: newBlock); |
237 | |
238 | // Create necessary casts in new block. |
239 | SmallVector<Value> castResults; |
240 | for (auto [i, arg] : llvm::enumerate(First: block->getArguments())) { |
241 | TypeRange convertedTypes = argumentConversion.getConvertedTypes(originalTypeNo: i); |
242 | ValueRange newArgs = |
243 | argumentConversion.getConvertedValues(convertedValues: newBlock->getArguments(), originalValueNo: i); |
244 | if (isIdentityConversion(originalType: arg.getType(), convertedTypes)) { |
245 | // Identity conversion: take argument as is. |
246 | assert(newArgs.size() == 1); |
247 | castResults.push_back(Elt: newArgs.front()); |
248 | } else { |
249 | // Non-identity conversion: cast the converted arguments to the original |
250 | // type. |
251 | PatternRewriter::InsertionGuard g(*this); |
252 | setInsertionPointToStart(newBlock); |
253 | ValueRange castResult = buildUnrealizedCast(builder&: *this, resultTypes: arg.getType(), inputs: newArgs, |
254 | kind: CastKind::Argument); |
255 | assert(castResult.size() == 1); |
256 | castResults.push_back(Elt: castResult.front()); |
257 | } |
258 | } |
259 | |
260 | // Merge old block into new block such that we only have the latter with the |
261 | // new signature. |
262 | mergeBlocks(source: block, dest: newBlock, argValues: castResults); |
263 | |
264 | return newBlock; |
265 | } |
266 | |
267 | LogicalResult |
268 | OneToNConversionPattern::matchAndRewrite(Operation *op, |
269 | PatternRewriter &rewriter) const { |
270 | auto *typeConverter = getTypeConverter<OneToNTypeConverter>(); |
271 | |
272 | // Construct conversion mapping for results. |
273 | Operation::result_type_range originalResultTypes = op->getResultTypes(); |
274 | OneToNTypeMapping resultMapping(originalResultTypes); |
275 | if (failed(result: typeConverter->computeTypeMapping(types: originalResultTypes, |
276 | result&: resultMapping))) |
277 | return failure(); |
278 | |
279 | // Construct conversion mapping for operands. |
280 | Operation::operand_type_range originalOperandTypes = op->getOperandTypes(); |
281 | OneToNTypeMapping operandMapping(originalOperandTypes); |
282 | if (failed(result: typeConverter->computeTypeMapping(types: originalOperandTypes, |
283 | result&: operandMapping))) |
284 | return failure(); |
285 | |
286 | // Cast operands to target types. |
287 | SmallVector<Value> convertedOperands = buildUnrealizedForwardCasts( |
288 | originalValues: op->getOperands(), conversion&: operandMapping, rewriter, kind: CastKind::Target); |
289 | |
290 | // Create a `OneToNPatternRewriter` for the pattern, which provides additional |
291 | // functionality. |
292 | // TODO(ingomueller): I guess it would be better to use only one rewriter |
293 | // throughout the whole pass, but that would require to |
294 | // drive the pattern application ourselves, which is a lot |
295 | // of additional boilerplate code. This seems to work fine, |
296 | // so I leave it like this for the time being. |
297 | OneToNPatternRewriter oneToNPatternRewriter(rewriter.getContext(), |
298 | rewriter.getListener()); |
299 | oneToNPatternRewriter.restoreInsertionPoint(ip: rewriter.saveInsertionPoint()); |
300 | |
301 | // Apply actual pattern. |
302 | if (failed(result: matchAndRewrite(op, rewriter&: oneToNPatternRewriter, operandMapping, |
303 | resultMapping, convertedOperands))) |
304 | return failure(); |
305 | |
306 | return success(); |
307 | } |
308 | |
309 | namespace mlir { |
310 | |
311 | // This function applies the provided patterns using |
312 | // `applyPatternsAndFoldGreedily` and then replaces all newly inserted |
313 | // `UnrealizedConversionCastOps` that haven't folded away. ("Backward" casts |
314 | // from target to source types inserted by a `OneToNConversionPattern` normally |
315 | // fold away with the "forward" casts from source to target types inserted by |
316 | // the next pattern.) To understand which casts are "newly inserted", all casts |
317 | // inserted by this pass are annotated with a string attribute that also |
318 | // documents which kind of the cast (source, argument, or target). |
319 | LogicalResult |
320 | applyPartialOneToNConversion(Operation *op, OneToNTypeConverter &typeConverter, |
321 | const FrozenRewritePatternSet &patterns) { |
322 | #ifndef NDEBUG |
323 | // Remember existing unrealized casts. This data structure is only used in |
324 | // asserts; building it only for that purpose may be an overkill. |
325 | SmallSet<UnrealizedConversionCastOp, 4> existingCasts; |
326 | op->walk(callback: [&](UnrealizedConversionCastOp castOp) { |
327 | assert(!castOp->hasAttr(castKindAttrName)); |
328 | existingCasts.insert(castOp); |
329 | }); |
330 | #endif // NDEBUG |
331 | |
332 | // Apply provided conversion patterns. |
333 | if (failed(result: applyPatternsAndFoldGreedily(op, patterns))) { |
334 | emitError(loc: op->getLoc()) << "failed to apply conversion patterns" ; |
335 | return failure(); |
336 | } |
337 | |
338 | // Find all unrealized casts inserted by the pass that haven't folded away. |
339 | SmallVector<UnrealizedConversionCastOp> worklist; |
340 | op->walk(callback: [&](UnrealizedConversionCastOp castOp) { |
341 | if (castOp->hasAttr(castKindAttrName)) { |
342 | assert(!existingCasts.contains(castOp)); |
343 | worklist.push_back(castOp); |
344 | } |
345 | }); |
346 | |
347 | // Replace new casts with user materializations. |
348 | IRRewriter rewriter(op->getContext()); |
349 | for (UnrealizedConversionCastOp castOp : worklist) { |
350 | TypeRange resultTypes = castOp->getResultTypes(); |
351 | ValueRange operands = castOp->getOperands(); |
352 | StringRef castKind = |
353 | castOp->getAttrOfType<StringAttr>(castKindAttrName).getValue(); |
354 | rewriter.setInsertionPoint(castOp); |
355 | |
356 | #ifndef NDEBUG |
357 | // Determine whether operands or results are already legal to test some |
358 | // assumptions for the different kind of materializations. These properties |
359 | // are only used it asserts and it may be overkill to compute them. |
360 | bool areOperandTypesLegal = llvm::all_of( |
361 | operands.getTypes(), [&](Type t) { return typeConverter.isLegal(t); }); |
362 | bool areResultsTypesLegal = llvm::all_of( |
363 | resultTypes, [&](Type t) { return typeConverter.isLegal(t); }); |
364 | #endif // NDEBUG |
365 | |
366 | // Add materialization and remember materialized results. |
367 | SmallVector<Value> materializedResults; |
368 | if (castKind == getCastKindName(CastKind::Target)) { |
369 | // Target materialization. |
370 | assert(!areOperandTypesLegal && areResultsTypesLegal && |
371 | operands.size() == 1 && "found unexpected target cast" ); |
372 | std::optional<SmallVector<Value>> maybeResults = |
373 | typeConverter.materializeTargetConversion( |
374 | rewriter, castOp->getLoc(), resultTypes, operands.front()); |
375 | if (!maybeResults) { |
376 | emitError(castOp->getLoc()) |
377 | << "failed to create target materialization" ; |
378 | return failure(); |
379 | } |
380 | materializedResults = maybeResults.value(); |
381 | } else { |
382 | // Source and argument materializations. |
383 | assert(areOperandTypesLegal && !areResultsTypesLegal && |
384 | resultTypes.size() == 1 && "found unexpected cast" ); |
385 | std::optional<Value> maybeResult; |
386 | if (castKind == getCastKindName(CastKind::Source)) { |
387 | // Source materialization. |
388 | maybeResult = typeConverter.materializeSourceConversion( |
389 | rewriter, castOp->getLoc(), resultTypes.front(), |
390 | castOp.getOperands()); |
391 | } else { |
392 | // Argument materialization. |
393 | assert(castKind == getCastKindName(CastKind::Argument) && |
394 | "unexpected value of cast kind attribute" ); |
395 | assert(llvm::all_of(operands, llvm::IsaPred<BlockArgument>)); |
396 | maybeResult = typeConverter.materializeArgumentConversion( |
397 | rewriter, castOp->getLoc(), resultTypes.front(), |
398 | castOp.getOperands()); |
399 | } |
400 | if (!maybeResult.has_value() || !maybeResult.value()) { |
401 | emitError(castOp->getLoc()) |
402 | << "failed to create " << castKind << " materialization" ; |
403 | return failure(); |
404 | } |
405 | materializedResults = {maybeResult.value()}; |
406 | } |
407 | |
408 | // Replace the cast with the result of the materialization. |
409 | rewriter.replaceOp(castOp, materializedResults); |
410 | } |
411 | |
412 | return success(); |
413 | } |
414 | |
415 | } // namespace mlir |
416 | |