"""Constant-folding diagnostics for literal-only expressions.
Numeric diagnostic refs are limited to JSON-stable values so Python and
jsfcstm agree on values that cross the JSON boundary.
"""
import math
from typing import TYPE_CHECKING, Any, List, Optional
from ...utils.validate import ModelDiagnostic
if TYPE_CHECKING: # pragma: no cover
from ...model.expr import Expr
from ...model.model import OperationStatement, StateMachine
# Type alias for values produced by numeric constant folding.
ConstValue = Any
# 2**53 - 1, i.e. JavaScript's Number.MAX_SAFE_INTEGER; larger ints are not
# treated as JSON-stable.
_MAX_JSON_STABLE_INT = 2 ** 53 - 1
# Shift distances above this bound are never folded.
_MAX_FOLD_SHIFT_BITS = 1024
# Binary operators whose folded result is a boolean.
_COMPARISON_OPS = {'==', '!=', '<', '<=', '>', '>='}
[docs]
def fold_numeric_expression(expr: 'Expr') -> Optional[ConstValue]:
    """Evaluate a literal-only numeric expression at analysis time.

    Integer and float literals yield their value. Unary ``+``/``-``, binary
    operators (via ``_fold_numeric_binary``) and conditional expressions whose
    condition folds are evaluated recursively.

    Returns:
        The folded number, or ``None`` as soon as any sub-expression cannot
        be folded.
    """
    from ...model.expr import BinaryOp, ConditionalOp, Float, Integer, UnaryOp

    if isinstance(expr, (Integer, Float)):
        return expr.value

    if isinstance(expr, UnaryOp):
        operand = fold_numeric_expression(expr.x)
        if operand is None or expr.op not in ('+', '-'):
            return None
        return -operand if expr.op == '-' else +operand

    if isinstance(expr, BinaryOp):
        lhs = fold_numeric_expression(expr.x)
        rhs = fold_numeric_expression(expr.y)
        if lhs is None or rhs is None:
            return None
        return _fold_numeric_binary(expr.op, lhs, rhs)

    if isinstance(expr, ConditionalOp):
        chosen = fold_condition_expression(expr.cond)
        if chosen is None:
            return None
        branch = expr.if_true if chosen else expr.if_false
        return fold_numeric_expression(branch)

    return None
[docs]
def fold_condition_expression(expr: 'Expr') -> Optional[bool]:
    """Evaluate a literal-only condition expression at analysis time.

    Supports boolean literals, ``!``, ``&&``/``||``, comparisons (via
    ``_fold_comparison``) and conditional expressions. Both operands of
    ``&&``/``||`` must fold; no short-circuit folding is attempted.

    Returns:
        The folded boolean, or ``None`` when the value is not static.
    """
    from ...model.expr import BinaryOp, Boolean, ConditionalOp, UnaryOp

    if isinstance(expr, Boolean):
        return expr.value

    if isinstance(expr, UnaryOp):
        if expr.op != '!':
            return None
        inner = fold_condition_expression(expr.x)
        if inner is None:
            return None
        return not inner

    if isinstance(expr, BinaryOp):
        op = expr.op
        if op in _COMPARISON_OPS:
            return _fold_comparison(expr)
        if op not in ('&&', '||'):
            return None
        lhs = fold_condition_expression(expr.x)
        rhs = fold_condition_expression(expr.y)
        if lhs is None or rhs is None:
            return None
        if op == '&&':
            return lhs and rhs
        return lhs or rhs

    if isinstance(expr, ConditionalOp):
        chosen = fold_condition_expression(expr.cond)
        if chosen is None:
            return None
        branch = expr.if_true if chosen else expr.if_false
        return fold_condition_expression(branch)

    return None
[docs]
def collect_const_fold_warnings(
    machine: Optional['StateMachine'],
) -> List[ModelDiagnostic]:
    """Gather every diagnostic that relies on constant folding.

    For each state yielded by ``machine.walk_states()``, this emits a guard
    diagnostic for each transition whose guard folds to a constant boolean.
    It then appends the during-action constant-assignment diagnostics for
    that state, checked against ``machine.defines``.

    Returns:
        The diagnostics in walk order; an empty list when ``machine`` is None.
    """
    if machine is None:
        return []
    variables = set(machine.defines)
    index_by_id = _transition_indexes(machine)
    found: List[ModelDiagnostic] = []
    for state in machine.walk_states():
        for transition in state.transitions:
            if transition.guard is None:
                continue
            constant = fold_condition_expression(transition.guard)
            # Only an exact True/False (a bool) produces a diagnostic.
            if not isinstance(constant, bool):
                continue
            found.append(
                _guard_const_diagnostic(
                    transition,
                    constant,
                    index_by_id.get(id(transition)),
                )
            )
        found.extend(
            _during_const_assign_diagnostics(_state_path(state), state, variables)
        )
    return found
def _fold_numeric_binary(
    op: str,
    left: ConstValue,
    right: ConstValue,
) -> Optional[ConstValue]:
    """Fold ``left <op> right`` for two already-folded numeric constants.

    Shift/bitwise operators are delegated to ``_fold_bitwise_binary`` and
    ``**`` to ``_fold_power``. Arithmetic results are passed through
    ``_stable_numeric_result``.

    Returns:
        The folded value, or ``None`` when ``op`` is unsupported or the
        operation is not folded. This covers division or modulo by zero,
        unsafe integer operands, and results that ``_stable_numeric_result``
        rejects.
    """
    if op in {'<<', '>>', '&', '^', '|'}:
        return _fold_bitwise_binary(op, left, right)
    if op in {'+', '-', '*', '/', '**'} and _has_unsafe_integer_operand(left, right):
        # Ints beyond +/-(2**53 - 1) are never used as arithmetic operands.
        return None
    if op == '+':
        return _stable_numeric_result(left + right)
    if op == '-':
        return _stable_numeric_result(left - right)
    if op == '*':
        return _stable_numeric_result(left * right)
    if op == '/':
        if right == 0:
            return None
        return _stable_numeric_result(left / right)
    if op == '%':
        if right == 0:
            return None
        both_ints = _is_plain_int(left) and _is_plain_int(right)
        if not both_ints and _has_unsafe_integer_operand(left, right):
            # int % int is exact at any size; an unsafe int paired with a
            # float is not folded.
            return None
        # NOTE(review): Python's ``%`` takes the sign of the divisor
        # (``-7 % 3 == 2``), unlike a C/JavaScript remainder (``-1``);
        # confirm this matches the DSL / jsfcstm semantics.
        return _stable_numeric_result(left % right)
    if op == '**':
        return _fold_power(left, right)
    return None


def _fold_bitwise_binary(op: str, left: ConstValue, right: ConstValue) -> Optional[ConstValue]:
    """Fold a shift or bitwise operator; both operands must be non-bool ints."""
    if not (_is_plain_int(left) and _is_plain_int(right)):
        return None
    if op in {'<<', '>>'} and not 0 <= right <= _MAX_FOLD_SHIFT_BITS:
        # Negative shift counts raise in Python; very large counts would
        # build enormous integers.
        return None
    if op == '<<':
        return left << right
    if op == '>>':
        return left >> right
    if op == '&':
        return left & right
    if op == '^':
        return left ^ right
    if op == '|':
        return left | right
    return None


def _fold_power(left: ConstValue, right: ConstValue) -> Optional[ConstValue]:
    """Fold ``left ** right``; ``None`` unless the result is a stable real number."""
    if left == 0 and right < 0:
        return None
    if (
        _is_plain_int(left)
        and _is_plain_int(right)
        and right >= 0
        and _integer_power_exceeds_json_stable_range(left, right)
    ):
        # Refuse before computing: an exact integer power can be huge.
        return None
    try:
        result = left ** right
    except (OverflowError, ValueError, ZeroDivisionError):
        # OverflowError: float result out of range (e.g. ``10.0 ** 400``).
        # ZeroDivisionError / ValueError are caught defensively; a zero base
        # with a negative exponent is already rejected above.
        return None
    if isinstance(result, complex):
        # Python 3 returns a complex number (it does not raise) for a
        # negative base with a fractional exponent, e.g. ``(-8) ** 0.5``.
        return None
    return _stable_numeric_result(result)
def _fold_comparison(expr) -> Optional[bool]:
    """Fold a comparison node when both operands are constant.

    Two numeric operands are compared with ``_compare_values``. Otherwise,
    ``==``/``!=`` can still fold when both operands fold as conditions.
    """
    lhs_num = fold_numeric_expression(expr.x)
    rhs_num = fold_numeric_expression(expr.y)
    if lhs_num is not None and rhs_num is not None:
        return _compare_values(expr.op, lhs_num, rhs_num)
    if expr.op in {'==', '!='}:
        lhs_bool = fold_condition_expression(expr.x)
        rhs_bool = fold_condition_expression(expr.y)
        if lhs_bool is not None and rhs_bool is not None:
            equal = lhs_bool == rhs_bool
            return equal if expr.op == '==' else not equal
    return None
def _compare_values(op: str, left: ConstValue, right: ConstValue) -> Optional[bool]:
    """Apply comparison ``op`` to two folded numbers.

    Returns ``None`` when ``_comparison_operands`` declines to compare the
    pair. Equality on an "approximate" pair is evaluated on float values.
    """
    operands = _comparison_operands(left, right)
    if operands is None:
        return None
    lhs, rhs, approximate = operands
    if op in ('==', '!='):
        if approximate:
            lhs, rhs = float(lhs), float(rhs)
        same = lhs == rhs
        return same if op == '==' else not same
    if op == '<':
        return lhs < rhs
    if op == '<=':
        return lhs <= rhs
    if op == '>':
        return lhs > rhs
    if op == '>=':
        return lhs >= rhs
    return False  # pragma: no cover
def _during_const_assign_diagnostics(state_path: str, state, defined_vars) -> List[ModelDiagnostic]:
    """Collect constant-assignment warnings for a state's concrete during actions.

    Actions that are abstract, refs, or have an ``aspect`` set are skipped.
    Each statement of the remaining actions is checked in order.
    """
    concrete_actions = [
        action
        for action in state.on_durings
        if not (action.is_abstract or action.is_ref or action.aspect is not None)
    ]
    return [
        diagnostic
        for action in concrete_actions
        for stmt in action.operations
        for diagnostic in _during_stmt_const_assign_diagnostics(state_path, stmt, defined_vars)
    ]
def _during_stmt_const_assign_diagnostics(
    state_path: str,
    stmt: 'OperationStatement',
    defined_vars,
) -> List[ModelDiagnostic]:
    """Warn when a during-action statement assigns a constant to a defined variable.

    Only ``Operation`` statements that target a name in *defined_vars* are
    reported, and only when the right-hand side folds to a JSON-stable
    number. Anything else yields an empty list.
    """
    from ...model.model import Operation
    if not isinstance(stmt, Operation) or stmt.var_name not in defined_vars:
        return []
    constant = _json_stable_number(fold_numeric_expression(stmt.expr))
    if constant is None:
        return []
    name = stmt.var_name
    warning = ModelDiagnostic(
        code='W_DURING_CONST_ASSIGN',
        severity='warning',
        message=(
            f'During action in {state_path!r} assigns {name!r} '
            'to the same constant value every cycle.'
        ),
        span=getattr(stmt, '_span', None),
        refs={
            'state_path': state_path,
            'var_name': name,
            'value': constant,
        },
    )
    return [warning]
def _expr_text(expr) -> Optional[str]:
    """Delegate to ``_expr_text`` from the sibling ``inspect`` module."""
    from ..inspect import _expr_text as render_expr
    return render_expr(expr)
def _transition_indexes(machine: 'StateMachine') -> dict:
    """Map ``id(transition)`` to its 0-based position across all states.

    Transitions are numbered consecutively in ``walk_states()`` order, then in
    the order of each state's ``transitions``.
    """
    every_transition = (
        transition
        for state in machine.walk_states()
        for transition in state.transitions
    )
    return {
        id(transition): position
        for position, transition in enumerate(every_transition)
    }
def _guard_const_diagnostic(
    transition,
    value: bool,
    transition_index: Optional[int],
) -> ModelDiagnostic:
    """Build the ``W_GUARD_CONST_TRUE`` / ``W_GUARD_CONST_FALSE`` warning for a transition."""
    if value:
        code, label = 'W_GUARD_CONST_TRUE', 'true'
    else:
        code, label = 'W_GUARD_CONST_FALSE', 'false'
    span = getattr(transition, '_span', None)
    source = _transition_endpoint_label(transition.from_state)
    target = _transition_endpoint_label(transition.to_state)
    refs = {
        'transition_span': span,
        'folded_value': value,
        'from_path': _transition_endpoint_path(transition, is_source=True),
        'to_path': _transition_endpoint_path(transition, is_source=False),
        'guard_text': _expr_text(transition.guard),
        'transition_index': transition_index,
    }
    return ModelDiagnostic(
        code=code,
        severity='warning',
        span=span,
        message=(
            f'Transition {source!r} -> {target!r} '
            f'has a guard that is statically {label}.'
        ),
        refs=refs,
    )
def _transition_endpoint_path(transition, *, is_source: bool) -> str:
    """Return the dotted path of a transition endpoint, or ``'[*]'`` for init/exit."""
    from ...dsl import EXIT_STATE, INIT_STATE
    if is_source:
        endpoint = transition.from_state
    else:
        endpoint = transition.to_state
    if endpoint is INIT_STATE or endpoint is EXIT_STATE:
        return '[*]'
    owner = transition.parent
    if owner is None:  # pragma: no cover - model-built transitions always have parents.
        return str(endpoint)
    segments = [*owner.path, endpoint]
    return '.'.join(str(segment) for segment in segments if segment is not None)
def _transition_endpoint_label(value) -> str:
    """Return ``'[*]'`` for the init/exit pseudo-states, otherwise ``str(value)``."""
    from ...dsl import EXIT_STATE, INIT_STATE
    is_pseudo_state = value is INIT_STATE or value is EXIT_STATE
    return '[*]' if is_pseudo_state else str(value)
def _state_path(state) -> str:
    """Join the non-``None`` entries of ``state.path`` with dots."""
    segments = [segment for segment in state.path if segment is not None]
    return '.'.join(segments)
def _is_plain_int(value: ConstValue) -> bool:
    """Return True for ``int`` instances that are not ``bool``."""
    if isinstance(value, bool):
        return False
    return isinstance(value, int)
def _has_unsafe_integer_operand(left: ConstValue, right: ConstValue) -> bool:
    """Return True if either operand is a non-bool int with magnitude above 2**53 - 1."""
    for operand in (left, right):
        if _is_plain_int(operand) and abs(operand) > _MAX_JSON_STABLE_INT:
            return True
    return False
def _json_stable_number(value: ConstValue) -> Optional[ConstValue]:
    """Normalize a folded number for use in JSON diagnostic refs.

    Ints within +/-(2**53 - 1) are returned unchanged. Finite integral floats
    in that range become ``int``, and other finite floats are kept.
    Everything else returns ``None``: bools, ``None``, non-finite or
    out-of-range values, and non-numeric objects.
    """
    if isinstance(value, bool):
        return None
    if isinstance(value, int):
        return value if abs(value) <= _MAX_JSON_STABLE_INT else None
    if not isinstance(value, float) or not math.isfinite(value):
        return None
    if not value.is_integer():
        return value
    return int(value) if abs(value) <= _MAX_JSON_STABLE_INT else None
def _stable_numeric_result(value: ConstValue) -> Optional[ConstValue]:
    """Accept an intermediate folding result only if it is a stable number.

    Ints must lie within +/-(2**53 - 1). Floats must be finite, and integral
    floats must also lie in that range. Floats keep their type; anything else
    (including bools and ``None``) yields ``None``.
    """
    if isinstance(value, bool):
        return None
    if isinstance(value, int):
        return None if abs(value) > _MAX_JSON_STABLE_INT else value
    if isinstance(value, float) and math.isfinite(value):
        too_large = value.is_integer() and abs(value) > _MAX_JSON_STABLE_INT
        return None if too_large else value
    return None
def _comparison_operands(left: ConstValue, right: ConstValue):
    """Prepare two folded numbers for comparison.

    Returns:
        ``(left, right, approximate)`` where ``approximate`` marks pairs that
        involve a float, or ``None`` when the comparison must not be folded.
        It is ``None`` when any float operand is non-finite (in float/float
        and mixed int/float pairs alike), when a float is paired with an int
        beyond +/-(2**53 - 1), or when no safe normalization exists.
    """
    left_float = isinstance(left, float)
    right_float = isinstance(right, float)
    if not left_float and not right_float:
        return left, right, False
    # Non-finite floats are never folded. Previously only the float/float
    # branch enforced this, so ``inf`` vs an int still folded.
    if (left_float and not math.isfinite(left)) or (right_float and not math.isfinite(right)):
        return None
    if left_float and right_float:
        return left, right, True
    if _mixed_float_and_unsafe_integer(left, right):
        return None
    if left_float:
        if _is_plain_int(right) and abs(right) <= _MAX_JSON_STABLE_INT:
            return left, right, True
        normalized_left = _safe_integer_float(left)
        if normalized_left is None:
            return None
        return normalized_left, right, False
    if _is_plain_int(left) and abs(left) <= _MAX_JSON_STABLE_INT:
        return left, right, True
    normalized_right = _safe_integer_float(right)
    if normalized_right is None:
        return None
    return left, normalized_right, False
def _mixed_float_and_unsafe_integer(left: ConstValue, right: ConstValue) -> bool:
    """Return True when one operand is a float and the other an int beyond +/-(2**53 - 1)."""
    def _unsafe_int(candidate) -> bool:
        return _is_plain_int(candidate) and abs(candidate) > _MAX_JSON_STABLE_INT

    return (isinstance(left, float) and _unsafe_int(right)) or (
        isinstance(right, float) and _unsafe_int(left)
    )
def _safe_integer_float(value: float) -> Optional[int]:
    """Convert a finite, integral float within +/-(2**53 - 1) to ``int``; else ``None``."""
    if not math.isfinite(value) or not value.is_integer():
        return None
    if abs(value) > _MAX_JSON_STABLE_INT:
        return None
    return int(value)
def _integer_power_exceeds_json_stable_range(base: int, exponent: int) -> bool:
    """Estimate, without computing the power, whether ``base ** exponent`` is too large.

    Compares ``exponent * log2(|base|)`` with ``log2(2**53 - 1)`` using float
    logarithms, so the answer is approximate right at the boundary. A zero
    exponent, or a base of -1, 0 or 1, never exceeds the range.
    """
    if exponent == 0 or base in {-1, 0, 1}:
        return False
    limit = math.log2(_MAX_JSON_STABLE_INT)
    return exponent * math.log2(abs(base)) > limit