1/*
2 * Copyright 2020 Google LLC
3 *
4 * Use of this source code is governed by a BSD-style license that can be
5 * found in the LICENSE file.
6 */
7
8#include "src/sksl/SkSLConstantFolder.h"
9
10#include "include/core/SkTypes.h"
11#include "include/private/base/SkFloatingPoint.h"
12#include "include/private/base/SkTArray.h"
13#include "src/sksl/SkSLAnalysis.h"
14#include "src/sksl/SkSLContext.h"
15#include "src/sksl/SkSLErrorReporter.h"
16#include "src/sksl/SkSLPosition.h"
17#include "src/sksl/SkSLProgramSettings.h"
18#include "src/sksl/ir/SkSLBinaryExpression.h"
19#include "src/sksl/ir/SkSLConstructorCompound.h"
20#include "src/sksl/ir/SkSLConstructorDiagonalMatrix.h"
21#include "src/sksl/ir/SkSLConstructorSplat.h"
22#include "src/sksl/ir/SkSLExpression.h"
23#include "src/sksl/ir/SkSLLiteral.h"
24#include "src/sksl/ir/SkSLModifierFlags.h"
25#include "src/sksl/ir/SkSLPrefixExpression.h"
26#include "src/sksl/ir/SkSLType.h"
27#include "src/sksl/ir/SkSLVariable.h"
28#include "src/sksl/ir/SkSLVariableReference.h"
29
30#include <cstdint>
31#include <float.h>
32#include <limits>
33#include <optional>
34#include <string>
35#include <utility>
36
37using namespace skia_private;
38
39namespace SkSL {
40
41static bool is_vec_or_mat(const Type& type) {
42 switch (type.typeKind()) {
43 case Type::TypeKind::kMatrix:
44 case Type::TypeKind::kVector:
45 return true;
46
47 default:
48 return false;
49 }
50}
51
52static std::unique_ptr<Expression> eliminate_no_op_boolean(Position pos,
53 const Expression& left,
54 Operator op,
55 const Expression& right) {
56 bool rightVal = right.as<Literal>().boolValue();
57
58 // Detect no-op Boolean expressions and optimize them away.
59 if ((op.kind() == Operator::Kind::LOGICALAND && rightVal) || // (expr && true) -> (expr)
60 (op.kind() == Operator::Kind::LOGICALOR && !rightVal) || // (expr || false) -> (expr)
61 (op.kind() == Operator::Kind::LOGICALXOR && !rightVal) || // (expr ^^ false) -> (expr)
62 (op.kind() == Operator::Kind::EQEQ && rightVal) || // (expr == true) -> (expr)
63 (op.kind() == Operator::Kind::NEQ && !rightVal)) { // (expr != false) -> (expr)
64
65 return left.clone(pos);
66 }
67
68 return nullptr;
69}
70
71static std::unique_ptr<Expression> short_circuit_boolean(Position pos,
72 const Expression& left,
73 Operator op,
74 const Expression& right) {
75 bool leftVal = left.as<Literal>().boolValue();
76
77 // When the literal is on the left, we can sometimes eliminate the other expression entirely.
78 if ((op.kind() == Operator::Kind::LOGICALAND && !leftVal) || // (false && expr) -> (false)
79 (op.kind() == Operator::Kind::LOGICALOR && leftVal)) { // (true || expr) -> (true)
80
81 return left.clone(pos);
82 }
83
84 // We can't eliminate the right-side expression via short-circuit, but we might still be able to
85 // simplify away a no-op expression.
86 return eliminate_no_op_boolean(pos, left: right, op, right: left);
87}
88
89static std::unique_ptr<Expression> simplify_constant_equality(const Context& context,
90 Position pos,
91 const Expression& left,
92 Operator op,
93 const Expression& right) {
94 if (op.kind() == Operator::Kind::EQEQ || op.kind() == Operator::Kind::NEQ) {
95 bool equality = (op.kind() == Operator::Kind::EQEQ);
96
97 switch (left.compareConstant(other: right)) {
98 case Expression::ComparisonResult::kNotEqual:
99 equality = !equality;
100 [[fallthrough]];
101
102 case Expression::ComparisonResult::kEqual:
103 return Literal::MakeBool(context, pos, value: equality);
104
105 case Expression::ComparisonResult::kUnknown:
106 break;
107 }
108 }
109 return nullptr;
110}
111
112static std::unique_ptr<Expression> simplify_matrix_multiplication(const Context& context,
113 Position pos,
114 const Expression& left,
115 const Expression& right,
116 int leftColumns,
117 int leftRows,
118 int rightColumns,
119 int rightRows) {
120 const Type& componentType = left.type().componentType();
121 SkASSERT(componentType.matches(right.type().componentType()));
122
123 // Fetch the left matrix.
124 double leftVals[4][4];
125 for (int c = 0; c < leftColumns; ++c) {
126 for (int r = 0; r < leftRows; ++r) {
127 leftVals[c][r] = *left.getConstantValue(n: (c * leftRows) + r);
128 }
129 }
130 // Fetch the right matrix.
131 double rightVals[4][4];
132 for (int c = 0; c < rightColumns; ++c) {
133 for (int r = 0; r < rightRows; ++r) {
134 rightVals[c][r] = *right.getConstantValue(n: (c * rightRows) + r);
135 }
136 }
137
138 SkASSERT(leftColumns == rightRows);
139 int outColumns = rightColumns,
140 outRows = leftRows;
141
142 double args[16];
143 int argIndex = 0;
144 for (int c = 0; c < outColumns; ++c) {
145 for (int r = 0; r < outRows; ++r) {
146 // Compute a dot product for this position.
147 double val = 0;
148 for (int dotIdx = 0; dotIdx < leftColumns; ++dotIdx) {
149 val += leftVals[dotIdx][r] * rightVals[c][dotIdx];
150 }
151
152 if (val >= -FLT_MAX && val <= FLT_MAX) {
153 args[argIndex++] = val;
154 } else {
155 // The value is outside the 32-bit float range, or is NaN; do not optimize.
156 return nullptr;
157 }
158 }
159 }
160
161 if (outColumns == 1) {
162 // Matrix-times-vector conceptually makes a 1-column N-row matrix, but we return vecN.
163 std::swap(x&: outColumns, y&: outRows);
164 }
165
166 const Type& resultType = componentType.toCompound(context, columns: outColumns, rows: outRows);
167 return ConstructorCompound::MakeFromConstants(context, pos, type: resultType, values: args);
168}
169
170static std::unique_ptr<Expression> simplify_matrix_times_matrix(const Context& context,
171 Position pos,
172 const Expression& left,
173 const Expression& right) {
174 const Type& leftType = left.type();
175 const Type& rightType = right.type();
176
177 SkASSERT(leftType.isMatrix());
178 SkASSERT(rightType.isMatrix());
179
180 return simplify_matrix_multiplication(context, pos, left, right,
181 leftColumns: leftType.columns(), leftRows: leftType.rows(),
182 rightColumns: rightType.columns(), rightRows: rightType.rows());
183}
184
185static std::unique_ptr<Expression> simplify_vector_times_matrix(const Context& context,
186 Position pos,
187 const Expression& left,
188 const Expression& right) {
189 const Type& leftType = left.type();
190 const Type& rightType = right.type();
191
192 SkASSERT(leftType.isVector());
193 SkASSERT(rightType.isMatrix());
194
195 return simplify_matrix_multiplication(context, pos, left, right,
196 /*leftColumns=*/leftType.columns(), /*leftRows=*/1,
197 rightColumns: rightType.columns(), rightRows: rightType.rows());
198}
199
200static std::unique_ptr<Expression> simplify_matrix_times_vector(const Context& context,
201 Position pos,
202 const Expression& left,
203 const Expression& right) {
204 const Type& leftType = left.type();
205 const Type& rightType = right.type();
206
207 SkASSERT(leftType.isMatrix());
208 SkASSERT(rightType.isVector());
209
210 return simplify_matrix_multiplication(context, pos, left, right,
211 leftColumns: leftType.columns(), leftRows: leftType.rows(),
212 /*rightColumns=*/1, /*rightRows=*/rightType.columns());
213}
214
215static std::unique_ptr<Expression> simplify_componentwise(const Context& context,
216 Position pos,
217 const Expression& left,
218 Operator op,
219 const Expression& right) {
220 SkASSERT(is_vec_or_mat(left.type()));
221 SkASSERT(left.type().matches(right.type()));
222 const Type& type = left.type();
223
224 // Handle equality operations: == !=
225 if (std::unique_ptr<Expression> result = simplify_constant_equality(context, pos, left, op,
226 right)) {
227 return result;
228 }
229
230 // Handle floating-point arithmetic: + - * /
231 using FoldFn = double (*)(double, double);
232 FoldFn foldFn;
233 switch (op.kind()) {
234 case Operator::Kind::PLUS: foldFn = +[](double a, double b) { return a + b; }; break;
235 case Operator::Kind::MINUS: foldFn = +[](double a, double b) { return a - b; }; break;
236 case Operator::Kind::STAR: foldFn = +[](double a, double b) { return a * b; }; break;
237 case Operator::Kind::SLASH: foldFn = +[](double a, double b) { return a / b; }; break;
238 default:
239 return nullptr;
240 }
241
242 const Type& componentType = type.componentType();
243 SkASSERT(componentType.isNumber());
244
245 double minimumValue = componentType.minimumValue();
246 double maximumValue = componentType.maximumValue();
247
248 double args[16];
249 int numSlots = type.slotCount();
250 for (int i = 0; i < numSlots; i++) {
251 double value = foldFn(*left.getConstantValue(n: i), *right.getConstantValue(n: i));
252 if (value < minimumValue || value > maximumValue) {
253 return nullptr;
254 }
255 args[i] = value;
256 }
257 return ConstructorCompound::MakeFromConstants(context, pos, type, values: args);
258}
259
260static std::unique_ptr<Expression> splat_scalar(const Context& context,
261 const Expression& scalar,
262 const Type& type) {
263 if (type.isVector()) {
264 return ConstructorSplat::Make(context, pos: scalar.fPosition, type, arg: scalar.clone());
265 }
266 if (type.isMatrix()) {
267 int numSlots = type.slotCount();
268 ExpressionArray splatMatrix;
269 splatMatrix.reserve_exact(n: numSlots);
270 for (int index = 0; index < numSlots; ++index) {
271 splatMatrix.push_back(t: scalar.clone());
272 }
273 return ConstructorCompound::Make(context, pos: scalar.fPosition, type, args: std::move(splatMatrix));
274 }
275 SkDEBUGFAILF("unsupported type %s", type.description().c_str());
276 return nullptr;
277}
278
279static std::unique_ptr<Expression> cast_expression(const Context& context,
280 Position pos,
281 const Expression& expr,
282 const Type& type) {
283 SkASSERT(type.componentType().matches(expr.type().componentType()));
284 if (expr.type().isScalar()) {
285 if (type.isMatrix()) {
286 return ConstructorDiagonalMatrix::Make(context, pos, type, arg: expr.clone());
287 }
288 if (type.isVector()) {
289 return ConstructorSplat::Make(context, pos, type, arg: expr.clone());
290 }
291 }
292 if (type.matches(other: expr.type())) {
293 return expr.clone(pos);
294 }
295 // We can't cast matrices into vectors or vice-versa.
296 return nullptr;
297}
298
299static std::unique_ptr<Expression> zero_expression(const Context& context,
300 Position pos,
301 const Type& type) {
302 std::unique_ptr<Expression> zero = Literal::Make(pos, value: 0.0, type: &type.componentType());
303 if (type.isScalar()) {
304 return zero;
305 }
306 if (type.isVector()) {
307 return ConstructorSplat::Make(context, pos, type, arg: std::move(zero));
308 }
309 if (type.isMatrix()) {
310 return ConstructorDiagonalMatrix::Make(context, pos, type, arg: std::move(zero));
311 }
312 SkDEBUGFAILF("unsupported type %s", type.description().c_str());
313 return nullptr;
314}
315
316static std::unique_ptr<Expression> negate_expression(const Context& context,
317 Position pos,
318 const Expression& expr,
319 const Type& type) {
320 std::unique_ptr<Expression> ctor = cast_expression(context, pos, expr, type);
321 return ctor ? PrefixExpression::Make(context, pos, op: Operator::Kind::MINUS, base: std::move(ctor))
322 : nullptr;
323}
324
325bool ConstantFolder::GetConstantInt(const Expression& value, SKSL_INT* out) {
326 const Expression* expr = GetConstantValueForVariable(value);
327 if (!expr->isIntLiteral()) {
328 return false;
329 }
330 *out = expr->as<Literal>().intValue();
331 return true;
332}
333
334bool ConstantFolder::GetConstantValue(const Expression& value, double* out) {
335 const Expression* expr = GetConstantValueForVariable(value);
336 if (!expr->is<Literal>()) {
337 return false;
338 }
339 *out = expr->as<Literal>().value();
340 return true;
341}
342
343static bool contains_constant_zero(const Expression& expr) {
344 int numSlots = expr.type().slotCount();
345 for (int index = 0; index < numSlots; ++index) {
346 std::optional<double> slotVal = expr.getConstantValue(n: index);
347 if (slotVal.has_value() && *slotVal == 0.0) {
348 return true;
349 }
350 }
351 return false;
352}
353
354bool ConstantFolder::IsConstantSplat(const Expression& expr, double value) {
355 int numSlots = expr.type().slotCount();
356 for (int index = 0; index < numSlots; ++index) {
357 std::optional<double> slotVal = expr.getConstantValue(n: index);
358 if (!slotVal.has_value() || *slotVal != value) {
359 return false;
360 }
361 }
362 return true;
363}
364
365// Returns true if the expression is a square diagonal matrix containing `value`.
366static bool is_constant_diagonal(const Expression& expr, double value) {
367 SkASSERT(expr.type().isMatrix());
368 int columns = expr.type().columns();
369 int rows = expr.type().rows();
370 if (columns != rows) {
371 return false;
372 }
373 int slotIdx = 0;
374 for (int c = 0; c < columns; ++c) {
375 for (int r = 0; r < rows; ++r) {
376 double expectation = (c == r) ? value : 0;
377 std::optional<double> slotVal = expr.getConstantValue(n: slotIdx++);
378 if (!slotVal.has_value() || *slotVal != expectation) {
379 return false;
380 }
381 }
382 }
383 return true;
384}
385
386// Returns true if the expression is a scalar, vector, or diagonal matrix containing `value`.
387static bool is_constant_value(const Expression& expr, double value) {
388 return expr.type().isMatrix() ? is_constant_diagonal(expr, value)
389 : ConstantFolder::IsConstantSplat(expr, value);
390}
391
392// The expression represents the right-hand side of a division op. If the division can be
393// strength-reduced into multiplication by a reciprocal, returns that reciprocal as an expression.
394// Note that this only supports literal values with safe-to-use reciprocals, and returns null if
395// Expression contains anything else.
396static std::unique_ptr<Expression> make_reciprocal_expression(const Context& context,
397 const Expression& right) {
398 if (right.type().isMatrix() || !right.type().componentType().isFloat()) {
399 return nullptr;
400 }
401 // Verify that each slot contains a finite, non-zero literal, take its reciprocal.
402 double values[4];
403 int nslots = right.type().slotCount();
404 for (int index = 0; index < nslots; ++index) {
405 std::optional<double> value = right.getConstantValue(n: index);
406 if (!value) {
407 return nullptr;
408 }
409 *value = sk_ieee_double_divide(numer: 1.0, denom: *value);
410 if (*value >= -FLT_MAX && *value <= FLT_MAX && *value != 0.0) {
411 // The reciprocal can be represented safely as a finite 32-bit float.
412 values[index] = *value;
413 } else {
414 // The value is outside the 32-bit float range, or is NaN; do not optimize.
415 return nullptr;
416 }
417 }
418 // Turn the expression array into a compound constructor. (If this is a single-slot expression,
419 // this will return the literal as-is.)
420 return ConstructorCompound::MakeFromConstants(context, pos: right.fPosition, type: right.type(), values);
421}
422
423static bool error_on_divide_by_zero(const Context& context, Position pos, Operator op,
424 const Expression& right) {
425 switch (op.kind()) {
426 case Operator::Kind::SLASH:
427 case Operator::Kind::SLASHEQ:
428 case Operator::Kind::PERCENT:
429 case Operator::Kind::PERCENTEQ:
430 if (contains_constant_zero(expr: right)) {
431 context.fErrors->error(position: pos, msg: "division by zero");
432 return true;
433 }
434 return false;
435 default:
436 return false;
437 }
438}
439
440const Expression* ConstantFolder::GetConstantValueOrNull(const Expression& inExpr) {
441 const Expression* expr = &inExpr;
442 while (expr->is<VariableReference>()) {
443 const VariableReference& varRef = expr->as<VariableReference>();
444 if (varRef.refKind() != VariableRefKind::kRead) {
445 return nullptr;
446 }
447 const Variable& var = *varRef.variable();
448 if (!var.modifierFlags().isConst()) {
449 return nullptr;
450 }
451 expr = var.initialValue();
452 if (!expr) {
453 // Generally, const variables must have initial values. However, function parameters are
454 // an exception; they can be const but won't have an initial value.
455 return nullptr;
456 }
457 }
458 return Analysis::IsCompileTimeConstant(expr: *expr) ? expr : nullptr;
459}
460
461const Expression* ConstantFolder::GetConstantValueForVariable(const Expression& inExpr) {
462 const Expression* expr = GetConstantValueOrNull(inExpr);
463 return expr ? expr : &inExpr;
464}
465
466std::unique_ptr<Expression> ConstantFolder::MakeConstantValueForVariable(
467 Position pos, std::unique_ptr<Expression> inExpr) {
468 const Expression* expr = GetConstantValueOrNull(inExpr: *inExpr);
469 return expr ? expr->clone(pos) : std::move(inExpr);
470}
471
472static bool is_scalar_op_matrix(const Expression& left, const Expression& right) {
473 return left.type().isScalar() && right.type().isMatrix();
474}
475
476static bool is_matrix_op_scalar(const Expression& left, const Expression& right) {
477 return is_scalar_op_matrix(left: right, right: left);
478}
479
480static std::unique_ptr<Expression> simplify_arithmetic(const Context& context,
481 Position pos,
482 const Expression& left,
483 Operator op,
484 const Expression& right,
485 const Type& resultType) {
486 switch (op.kind()) {
487 case Operator::Kind::PLUS:
488 if (!is_scalar_op_matrix(left, right) &&
489 ConstantFolder::IsConstantSplat(expr: right, value: 0.0)) { // x + 0
490 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, expr: left,
491 type: resultType)) {
492 return expr;
493 }
494 }
495 if (!is_matrix_op_scalar(left, right) &&
496 ConstantFolder::IsConstantSplat(expr: left, value: 0.0)) { // 0 + x
497 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, expr: right,
498 type: resultType)) {
499 return expr;
500 }
501 }
502 break;
503
504 case Operator::Kind::STAR:
505 if (is_constant_value(expr: right, value: 1.0)) { // x * 1
506 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, expr: left,
507 type: resultType)) {
508 return expr;
509 }
510 }
511 if (is_constant_value(expr: left, value: 1.0)) { // 1 * x
512 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, expr: right,
513 type: resultType)) {
514 return expr;
515 }
516 }
517 if (is_constant_value(expr: right, value: 0.0) && !Analysis::HasSideEffects(expr: left)) { // x * 0
518 return zero_expression(context, pos, type: resultType);
519 }
520 if (is_constant_value(expr: left, value: 0.0) && !Analysis::HasSideEffects(expr: right)) { // 0 * x
521 return zero_expression(context, pos, type: resultType);
522 }
523 if (is_constant_value(expr: right, value: -1.0)) { // x * -1 (to `-x`)
524 if (std::unique_ptr<Expression> expr = negate_expression(context, pos, expr: left,
525 type: resultType)) {
526 return expr;
527 }
528 }
529 if (is_constant_value(expr: left, value: -1.0)) { // -1 * x (to `-x`)
530 if (std::unique_ptr<Expression> expr = negate_expression(context, pos, expr: right,
531 type: resultType)) {
532 return expr;
533 }
534 }
535 break;
536
537 case Operator::Kind::MINUS:
538 if (!is_scalar_op_matrix(left, right) &&
539 ConstantFolder::IsConstantSplat(expr: right, value: 0.0)) { // x - 0
540 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, expr: left,
541 type: resultType)) {
542 return expr;
543 }
544 }
545 if (!is_matrix_op_scalar(left, right) &&
546 ConstantFolder::IsConstantSplat(expr: left, value: 0.0)) { // 0 - x
547 if (std::unique_ptr<Expression> expr = negate_expression(context, pos, expr: right,
548 type: resultType)) {
549 return expr;
550 }
551 }
552 break;
553
554 case Operator::Kind::SLASH:
555 if (!is_scalar_op_matrix(left, right) &&
556 ConstantFolder::IsConstantSplat(expr: right, value: 1.0)) { // x / 1
557 if (std::unique_ptr<Expression> expr = cast_expression(context, pos, expr: left,
558 type: resultType)) {
559 return expr;
560 }
561 }
562 if (!left.type().isMatrix()) { // convert `x / 2` into `x * 0.5`
563 if (std::unique_ptr<Expression> expr = make_reciprocal_expression(context, right)) {
564 return BinaryExpression::Make(context, pos, left: left.clone(), op: Operator::Kind::STAR,
565 right: std::move(expr));
566 }
567 }
568 break;
569
570 case Operator::Kind::PLUSEQ:
571 case Operator::Kind::MINUSEQ:
572 if (ConstantFolder::IsConstantSplat(expr: right, value: 0.0)) { // x += 0, x -= 0
573 if (std::unique_ptr<Expression> var = cast_expression(context, pos, expr: left,
574 type: resultType)) {
575 Analysis::UpdateVariableRefKind(expr: var.get(), kind: VariableRefKind::kRead);
576 return var;
577 }
578 }
579 break;
580
581 case Operator::Kind::STAREQ:
582 if (is_constant_value(expr: right, value: 1.0)) { // x *= 1
583 if (std::unique_ptr<Expression> var = cast_expression(context, pos, expr: left,
584 type: resultType)) {
585 Analysis::UpdateVariableRefKind(expr: var.get(), kind: VariableRefKind::kRead);
586 return var;
587 }
588 }
589 break;
590
591 case Operator::Kind::SLASHEQ:
592 if (ConstantFolder::IsConstantSplat(expr: right, value: 1.0)) { // x /= 1
593 if (std::unique_ptr<Expression> var = cast_expression(context, pos, expr: left,
594 type: resultType)) {
595 Analysis::UpdateVariableRefKind(expr: var.get(), kind: VariableRefKind::kRead);
596 return var;
597 }
598 }
599 if (std::unique_ptr<Expression> expr = make_reciprocal_expression(context, right)) {
600 return BinaryExpression::Make(context, pos, left: left.clone(), op: Operator::Kind::STAREQ,
601 right: std::move(expr));
602 }
603 break;
604
605 default:
606 break;
607 }
608
609 return nullptr;
610}
611
612// The expression must be scalar, and represents the right-hand side of a division op. It can
613// contain anything, not just literal values. This returns the binary expression `1.0 / expr`. The
614// expression might be further simplified by the constant folding, if possible.
615static std::unique_ptr<Expression> one_over_scalar(const Context& context,
616 const Expression& right) {
617 SkASSERT(right.type().isScalar());
618 Position pos = right.fPosition;
619 return BinaryExpression::Make(context, pos,
620 left: Literal::Make(pos, value: 1.0, type: &right.type()),
621 op: Operator::Kind::SLASH,
622 right: right.clone());
623}
624
625static std::unique_ptr<Expression> simplify_matrix_division(const Context& context,
626 Position pos,
627 const Expression& left,
628 Operator op,
629 const Expression& right,
630 const Type& resultType) {
631 // Convert matrix-over-scalar `x /= y` into `x *= (1.0 / y)`. This generates better
632 // code in SPIR-V and Metal, and should be roughly equivalent elsewhere.
633 switch (op.kind()) {
634 case OperatorKind::SLASH:
635 case OperatorKind::SLASHEQ:
636 if (left.type().isMatrix() && right.type().isScalar()) {
637 Operator multiplyOp = op.isAssignment() ? OperatorKind::STAREQ
638 : OperatorKind::STAR;
639 return BinaryExpression::Make(context, pos,
640 left: left.clone(),
641 op: multiplyOp,
642 right: one_over_scalar(context, right));
643 }
644 break;
645
646 default:
647 break;
648 }
649
650 return nullptr;
651}
652
653static std::unique_ptr<Expression> fold_expression(Position pos,
654 double result,
655 const Type* resultType) {
656 if (resultType->isNumber()) {
657 if (result >= resultType->minimumValue() && result <= resultType->maximumValue()) {
658 // This result will fit inside its type.
659 } else {
660 // The value is outside the range or is NaN (all if-checks fail); do not optimize.
661 return nullptr;
662 }
663 }
664
665 return Literal::Make(pos, value: result, type: resultType);
666}
667
668std::unique_ptr<Expression> ConstantFolder::Simplify(const Context& context,
669 Position pos,
670 const Expression& leftExpr,
671 Operator op,
672 const Expression& rightExpr,
673 const Type& resultType) {
674 // Replace constant variables with their literal values.
675 const Expression* left = GetConstantValueForVariable(inExpr: leftExpr);
676 const Expression* right = GetConstantValueForVariable(inExpr: rightExpr);
677
678 // If this is the assignment operator, and both sides are the same trivial expression, this is
679 // self-assignment (i.e., `var = var`) and can be reduced to just a variable reference (`var`).
680 // This can happen when other parts of the assignment are optimized away.
681 if (op.kind() == Operator::Kind::EQ && Analysis::IsSameExpressionTree(left: *left, right: *right)) {
682 return right->clone(pos);
683 }
684
685 // Simplify the expression when both sides are constant Boolean literals.
686 if (left->isBoolLiteral() && right->isBoolLiteral()) {
687 bool leftVal = left->as<Literal>().boolValue();
688 bool rightVal = right->as<Literal>().boolValue();
689 bool result;
690 switch (op.kind()) {
691 case Operator::Kind::LOGICALAND: result = leftVal && rightVal; break;
692 case Operator::Kind::LOGICALOR: result = leftVal || rightVal; break;
693 case Operator::Kind::LOGICALXOR: result = leftVal ^ rightVal; break;
694 case Operator::Kind::EQEQ: result = leftVal == rightVal; break;
695 case Operator::Kind::NEQ: result = leftVal != rightVal; break;
696 default: return nullptr;
697 }
698 return Literal::MakeBool(context, pos, value: result);
699 }
700
701 // If the left side is a Boolean literal, apply short-circuit optimizations.
702 if (left->isBoolLiteral()) {
703 return short_circuit_boolean(pos, left: *left, op, right: *right);
704 }
705
706 // If the right side is a Boolean literal...
707 if (right->isBoolLiteral()) {
708 // ... and the left side has no side effects...
709 if (!Analysis::HasSideEffects(expr: *left)) {
710 // We can reverse the expressions and short-circuit optimizations are still valid.
711 return short_circuit_boolean(pos, left: *right, op, right: *left);
712 }
713
714 // We can't use short-circuiting, but we can still optimize away no-op Boolean expressions.
715 return eliminate_no_op_boolean(pos, left: *left, op, right: *right);
716 }
717
718 if (op.kind() == Operator::Kind::EQEQ && Analysis::IsSameExpressionTree(left: *left, right: *right)) {
719 // With == comparison, if both sides are the same trivial expression, this is self-
720 // comparison and is always true. (We are not concerned with NaN.)
721 return Literal::MakeBool(context, pos, /*value=*/true);
722 }
723
724 if (op.kind() == Operator::Kind::NEQ && Analysis::IsSameExpressionTree(left: *left, right: *right)) {
725 // With != comparison, if both sides are the same trivial expression, this is self-
726 // comparison and is always false. (We are not concerned with NaN.)
727 return Literal::MakeBool(context, pos, /*value=*/false);
728 }
729
730 if (error_on_divide_by_zero(context, pos, op, right: *right)) {
731 return nullptr;
732 }
733
734 // Perform full constant folding when both sides are compile-time constants.
735 const Type& leftType = left->type();
736 const Type& rightType = right->type();
737 bool leftSideIsConstant = Analysis::IsCompileTimeConstant(expr: *left);
738 bool rightSideIsConstant = Analysis::IsCompileTimeConstant(expr: *right);
739
740 if (leftSideIsConstant && rightSideIsConstant) {
741 // Handle pairs of integer literals.
742 if (left->isIntLiteral() && right->isIntLiteral()) {
743 using SKSL_UINT = uint64_t;
744 SKSL_INT leftVal = left->as<Literal>().intValue();
745 SKSL_INT rightVal = right->as<Literal>().intValue();
746
747 // Note that fold_expression returns null if the result would overflow its type.
748 #define RESULT(Op) fold_expression(pos, (SKSL_INT)(leftVal) Op \
749 (SKSL_INT)(rightVal), &resultType)
750 #define URESULT(Op) fold_expression(pos, (SKSL_INT)((SKSL_UINT)(leftVal) Op \
751 (SKSL_UINT)(rightVal)), &resultType)
752 switch (op.kind()) {
753 case Operator::Kind::PLUS: return URESULT(+);
754 case Operator::Kind::MINUS: return URESULT(-);
755 case Operator::Kind::STAR: return URESULT(*);
756 case Operator::Kind::SLASH:
757 if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
758 context.fErrors->error(position: pos, msg: "arithmetic overflow");
759 return nullptr;
760 }
761 return RESULT(/);
762 case Operator::Kind::PERCENT:
763 if (leftVal == std::numeric_limits<SKSL_INT>::min() && rightVal == -1) {
764 context.fErrors->error(position: pos, msg: "arithmetic overflow");
765 return nullptr;
766 }
767 return RESULT(%);
768 case Operator::Kind::BITWISEAND: return RESULT(&);
769 case Operator::Kind::BITWISEOR: return RESULT(|);
770 case Operator::Kind::BITWISEXOR: return RESULT(^);
771 case Operator::Kind::EQEQ: return RESULT(==);
772 case Operator::Kind::NEQ: return RESULT(!=);
773 case Operator::Kind::GT: return RESULT(>);
774 case Operator::Kind::GTEQ: return RESULT(>=);
775 case Operator::Kind::LT: return RESULT(<);
776 case Operator::Kind::LTEQ: return RESULT(<=);
777 case Operator::Kind::SHL:
778 if (rightVal >= 0 && rightVal <= 31) {
779 // Left-shifting a negative (or really, any signed) value is undefined
780 // behavior in C++, but not in GLSL. Do the shift on unsigned values to avoid
781 // triggering an UBSAN error.
782 return URESULT(<<);
783 }
784 context.fErrors->error(position: pos, msg: "shift value out of range");
785 return nullptr;
786 case Operator::Kind::SHR:
787 if (rightVal >= 0 && rightVal <= 31) {
788 return RESULT(>>);
789 }
790 context.fErrors->error(position: pos, msg: "shift value out of range");
791 return nullptr;
792
793 default:
794 return nullptr;
795 }
796 #undef RESULT
797 #undef URESULT
798 }
799
800 // Handle pairs of floating-point literals.
801 if (left->isFloatLiteral() && right->isFloatLiteral()) {
802 SKSL_FLOAT leftVal = left->as<Literal>().floatValue();
803 SKSL_FLOAT rightVal = right->as<Literal>().floatValue();
804
805 #define RESULT(Op) fold_expression(pos, leftVal Op rightVal, &resultType)
806 switch (op.kind()) {
807 case Operator::Kind::PLUS: return RESULT(+);
808 case Operator::Kind::MINUS: return RESULT(-);
809 case Operator::Kind::STAR: return RESULT(*);
810 case Operator::Kind::SLASH: return RESULT(/);
811 case Operator::Kind::EQEQ: return RESULT(==);
812 case Operator::Kind::NEQ: return RESULT(!=);
813 case Operator::Kind::GT: return RESULT(>);
814 case Operator::Kind::GTEQ: return RESULT(>=);
815 case Operator::Kind::LT: return RESULT(<);
816 case Operator::Kind::LTEQ: return RESULT(<=);
817 default: return nullptr;
818 }
819 #undef RESULT
820 }
821
822 // Perform matrix multiplication.
823 if (op.kind() == Operator::Kind::STAR) {
824 if (leftType.isMatrix() && rightType.isMatrix()) {
825 return simplify_matrix_times_matrix(context, pos, left: *left, right: *right);
826 }
827 if (leftType.isVector() && rightType.isMatrix()) {
828 return simplify_vector_times_matrix(context, pos, left: *left, right: *right);
829 }
830 if (leftType.isMatrix() && rightType.isVector()) {
831 return simplify_matrix_times_vector(context, pos, left: *left, right: *right);
832 }
833 }
834
835 // Perform constant folding on pairs of vectors/matrices.
836 if (is_vec_or_mat(type: leftType) && leftType.matches(other: rightType)) {
837 return simplify_componentwise(context, pos, left: *left, op, right: *right);
838 }
839
840 // Perform constant folding on vectors/matrices against scalars, e.g.: half4(2) + 2
841 if (rightType.isScalar() && is_vec_or_mat(type: leftType) &&
842 leftType.componentType().matches(other: rightType)) {
843 return simplify_componentwise(context, pos,
844 left: *left, op, right: *splat_scalar(context, scalar: *right, type: left->type()));
845 }
846
847 // Perform constant folding on scalars against vectors/matrices, e.g.: 2 + half4(2)
848 if (leftType.isScalar() && is_vec_or_mat(type: rightType) &&
849 rightType.componentType().matches(other: leftType)) {
850 return simplify_componentwise(context, pos,
851 left: *splat_scalar(context, scalar: *left, type: right->type()), op, right: *right);
852 }
853
854 // Perform constant folding on pairs of matrices, arrays or structs.
855 if ((leftType.isMatrix() && rightType.isMatrix()) ||
856 (leftType.isArray() && rightType.isArray()) ||
857 (leftType.isStruct() && rightType.isStruct())) {
858 return simplify_constant_equality(context, pos, left: *left, op, right: *right);
859 }
860 }
861
862 if (context.fConfig->fSettings.fOptimize) {
863 // If just one side is constant, we might still be able to simplify arithmetic expressions
864 // like `x * 1`, `x *= 1`, `x + 0`, `x * 0`, `0 / x`, etc.
865 if (leftSideIsConstant || rightSideIsConstant) {
866 if (std::unique_ptr<Expression> expr = simplify_arithmetic(context, pos, left: *left, op,
867 right: *right, resultType)) {
868 return expr;
869 }
870 }
871
872 // We can simplify some forms of matrix division even when neither side is constant.
873 if (std::unique_ptr<Expression> expr = simplify_matrix_division(context, pos, left: *left, op,
874 right: *right, resultType)) {
875 return expr;
876 }
877 }
878
879 // We aren't able to constant-fold.
880 return nullptr;
881}
882
883} // namespace SkSL
884

// Source code of flutter_engine/third_party/skia/src/sksl/SkSLConstantFolder.cpp