# This Source Code Form is subject to the terms of the Mozilla Public
# License, v. 2.0. If a copy of the MPL was not distributed with this
# file, You can obtain one at http://mozilla.org/MPL/2.0/.

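# This script generates jit/LIROpsGenerated.h (the list of LIR instructions and
# their generated class definitions) from LIROps.yaml and from the MIR
# instructions marked 'generate_lir: true' in the MIR ops YAML file.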

import io
from itertools import groupby
from operator import itemgetter

import buildconfig
import yaml
from mozbuild.preprocessor import Preprocessor

# Skeleton of the generated header. generate_header() substitutes the include
# guard name for %(includeguard)s and the generated definitions for
# %(contents)s.
HEADER_TEMPLATE = """\
/* This Source Code Form is subject to the terms of the Mozilla Public
 * License, v. 2.0. If a copy of the MPL was not distributed with this
 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */

#ifndef %(includeguard)s
#define %(includeguard)s

/* This file is generated by jit/GenerateLIRFiles.py. Do not edit! */

%(contents)s

#endif // %(includeguard)s
"""


def load_yaml(yaml_path):
    """Preprocess the YAML file at yaml_path and return the parsed document.

    The file is run through the build preprocessor first, with the build's
    ALLDEFINES as context, so that the YAML source may contain directives
    such as #ifdef JS_SIMULATOR.
    """
    preprocessor = Preprocessor()
    preprocessor.context.update(buildconfig.defines["ALLDEFINES"])

    # Capture the preprocessed text in memory instead of writing it out.
    preprocessor.out = io.StringIO()
    preprocessor.do_filter("substitution")
    preprocessor.do_include(yaml_path)

    return yaml.safe_load(preprocessor.out.getvalue())


def generate_header(c_out, includeguard, contents):
    """Write the complete generated header to c_out.

    HEADER_TEMPLATE is filled in with the include guard name and the header
    contents, and the result is emitted with a single write() call.
    """
    substitutions = dict(includeguard=includeguard, contents=contents)
    c_out.write(HEADER_TEMPLATE % substitutions)


# Maps each supported operand type to the C++ allocation class used when the
# operand is declared as a constructor parameter of the generated LIR class.
operand_types = {
    "WordSized": "LAllocation",
    "BoxedValue": "LBoxAllocation",
    "Int64": "LInt64Allocation",
}


# Maps each supported result type to the C++ expression emitted as the
# definition count, i.e. the first template argument of LInstructionHelper in
# the generated class.
result_types = {
    "WordSized": "1",
    "BoxedValue": "BOX_PIECES",
    "Int64": "INT64_PIECES",
}


# Build the C++ index expression for the BoxedValue operand at position
# |index| among the boxed operands.
#
# The full expression is |num_operands + index * BOX_PIECES|, where
# num_operands is the number of entries in |reg_operands|. Terms that are zero
# are left out; if every term is zero the result is "0".
def make_boxed_index(index, reg_operands):
    word_count = len(reg_operands)

    candidates = (
        (word_count, f"{word_count}"),
        (index, f"{index} * BOX_PIECES"),
    )
    terms = [text for amount, text in candidates if amount]
    return " + ".join(terms) or "0"


# Build the C++ index expression for the Int64 operand at position |index|
# among the int64 operands.
#
# The full expression is
# |num_operands + num_value_operands * BOX_PIECES + index * INT64_PIECES|,
# where the two counts are the lengths of |reg_operands| and |value_operands|.
# Terms that are zero are left out; if every term is zero the result is "0".
def make_int64_index(index, reg_operands, value_operands):
    word_count = len(reg_operands)
    boxed_count = len(value_operands)

    candidates = (
        (word_count, f"{word_count}"),
        (boxed_count, f"{boxed_count} * BOX_PIECES"),
        (index, f"{index} * INT64_PIECES"),
    )
    terms = [text for amount, text in candidates if amount]
    return " + ".join(terms) or "0"


def gen_operands(operands, defer_init):
    """Generate the operand-related pieces of an LIR class definition.

    operands maps each operand name to its type ("WordSized", "BoxedValue" or
    "Int64"), in the order the operands are listed in the YAML file. Operand
    indices are assigned grouped by type: word-sized operands first, then
    boxed values, then int64 values.

    When defer_init is false, the operands become constructor parameters and
    are set in the constructor body. When it is true, no parameters or
    initializers are produced and a set<Operand>() method is generated for
    each operand instead.

    Returns the tuple
    (num_operands, indices, params, initializers, getters, setters), where
    num_operands is a C++ expression for the total operand count and the
    remaining items are lists of C++ source fragments.

    Raises Exception if any operand has an unsupported type.
    """
    # Exactly three operand types are supported: WordSized, BoxedValue, and
    # Int64. Check the type names themselves rather than counting the groups,
    # otherwise a misspelled type would silently drop its operands.
    supported_types = ("WordSized", "BoxedValue", "Int64")
    invalid_types = [ty for ty in operands.values() if ty not in supported_types]
    if invalid_types:
        raise Exception("Invalid operand type: " + str(invalid_types))

    # Group operands by operand type. The sort is stable, so operands of the
    # same type keep the order in which they were defined.
    sorted_operands = {
        k: [op for op, _ in v]
        for k, v in groupby(sorted(operands.items(), key=itemgetter(1)), itemgetter(1))
    }

    reg_operands = sorted_operands.get("WordSized", [])
    value_operands = sorted_operands.get("BoxedValue", [])
    int64_operands = sorted_operands.get("Int64", [])

    # Operand index definitions.
    indices = []

    # Parameters for the class constructor.
    params = []

    # Initializer instructions for constructor body.
    initializers = []

    # Getter definitions.
    getters = []

    # Setter definitions.
    setters = []

    # Constructor parameters are generated in the order defined in the YAML file.
    if not defer_init:
        for operand, op_type in operands.items():
            params.append(f"const {operand_types[op_type]}& {operand}")

    # First initialize all word-sized operands.
    for index, operand in enumerate(reg_operands):
        cap_operand = operand[0].upper() + operand[1:]
        index_value = cap_operand + "Index"
        init_expr = f"setOperand({index_value}, {operand});"

        indices.append(f"static constexpr size_t {index_value} = {index};")
        if not defer_init:
            initializers.append(init_expr)
        else:
            setters.append(
                f"void set{cap_operand}(const LAllocation& {operand}) {{ {init_expr} }}"
            )
        getters.append(
            f"const LAllocation* {operand}() const {{ return getOperand({index_value}); }}"
        )

    # Next initialize all BoxedValue operands.
    for box_index, operand in enumerate(value_operands):
        cap_operand = operand[0].upper() + operand[1:]
        index_value = cap_operand + "Index"
        init_expr = f"setBoxOperand({index_value}, {operand});"

        indices.append(
            f"static constexpr size_t {index_value} = {make_boxed_index(box_index, reg_operands)};"
        )
        if not defer_init:
            initializers.append(init_expr)
        else:
            # Use the same set<Operand>() naming as the other operand kinds.
            setters.append(
                f"void set{cap_operand}(const LBoxAllocation& {operand}) {{ {init_expr} }}"
            )
        getters.append(
            f"LBoxAllocation {operand}() const {{ return getBoxOperand({index_value}); }}"
        )

    # Finally initialize all Int64 operands.
    for int64_index, operand in enumerate(int64_operands):
        cap_operand = operand[0].upper() + operand[1:]
        index_value = cap_operand + "Index"
        init_expr = f"setInt64Operand({index_value}, {operand});"

        indices.append(
            f"static constexpr size_t {index_value} = {make_int64_index(int64_index, reg_operands, value_operands)};"
        )
        if not defer_init:
            initializers.append(init_expr)
        else:
            setters.append(
                f"void set{cap_operand}(const LInt64Allocation& {operand}) {{ {init_expr} }}"
            )
        getters.append(
            f"LInt64Allocation {operand}() const {{ return getInt64Operand({index_value}); }}"
        )

    # Total number of operands.
    num_operands = f"{len(reg_operands)}"
    if value_operands:
        num_operands += f" + {len(value_operands)} * BOX_PIECES"
    if int64_operands:
        num_operands += f" + {len(int64_operands)} * INT64_PIECES"

    return (
        num_operands,
        indices,
        params,
        initializers,
        getters,
        setters,
    )


def gen_arguments(arguments):
    """Generate the pieces of an LIR class that store extra arguments.

    arguments maps each argument name to its C++ type signature. For every
    entry this produces a member declaration, a constructor parameter, a
    member initializer and a const getter.

    Returns the tuple (members, params, initializers, getters) of lists of
    C++ source fragments, each in the iteration order of arguments.
    """
    entries = list(arguments.items())

    # Class member definitions.
    members = [f"{signature} {arg}_;" for arg, signature in entries]

    # Parameters for the class constructor.
    params = [f"{signature} {arg}" for arg, signature in entries]

    # Initializer instructions for the class constructor.
    initializers = [f"{arg}_({arg})" for arg, _ in entries]

    # Getter definitions.
    getters = [
        f"{signature} {arg}() const {{ return {arg}_; }}" for arg, signature in entries
    ]

    return (members, params, initializers, getters)


def gen_temps(num_temps, num_temps64, defer_init):
    """Generate the temp-related pieces of an LIR class definition.

    num_temps word-sized temps are laid out first, followed by num_temps64
    int64 temps. Temps are named temp0, temp1, ... across both kinds.

    When defer_init is false, every temp becomes a constructor parameter and
    is set in the constructor body. When it is true, the constructor sets
    every temp to a bogus temp and a setTemp<N>() method is generated for
    each one instead.

    Returns the tuple (num_temps_total, params, initializers, getters,
    setters), where num_temps_total is a C++ expression for the total temp
    count and the remaining items are lists of C++ source fragments.
    """
    # Parameters for the class constructor.
    params = []

    # Initializer instructions for constructor body.
    initializers = []

    # Getter definitions.
    getters = []

    # Setter definitions.
    setters = []

    for temp in range(num_temps):
        param_decl = f"const LDefinition& temp{temp}"
        init_expr = f"setTemp({temp}, temp{temp});"

        if not defer_init:
            params.append(param_decl)
            initializers.append(init_expr)
        else:
            initializers.append(f"setTemp({temp}, LDefinition::BogusTemp());")
            setters.append(f"void setTemp{temp}({param_decl}) {{ {init_expr} }}")
        getters.append(f"const LDefinition* temp{temp}() {{ return getTemp({temp}); }}")

    for int64_temp in range(num_temps64):
        # |temp| numbers the temp for naming purposes; |temp_index| is the
        # index expression used by the int64 temp accessors.
        temp = num_temps + int64_temp
        temp_index = f"{num_temps} + {int64_temp} * INT64_PIECES"
        param_decl = f"const LInt64Definition& temp{temp}"
        init_expr = f"setInt64Temp({temp_index}, temp{temp});"

        if not defer_init:
            params.append(param_decl)
            initializers.append(init_expr)
        else:
            # The placeholder must go through the int64 setter with the same
            # index expression as the real setter and getter below.
            initializers.append(
                f"setInt64Temp({temp_index}, LInt64Definition::BogusTemp());"
            )
            setters.append(f"void setTemp{temp}({param_decl}) {{ {init_expr} }}")
        getters.append(
            f"LInt64Definition temp{temp}() {{ return getInt64Temp({temp_index}); }}"
        )

    # Total number of temps.
    num_temps_total = f"{num_temps}"
    if num_temps64:
        num_temps_total += f" + {num_temps64} * INT64_PIECES"

    return (num_temps_total, params, initializers, getters, setters)


def gen_successors(successors):
    """Generate the successor-related pieces of an LIR class definition.

    successors is a list of successor names, or None (treated as empty).
    Each successor becomes an MBasicBlock* constructor parameter, is stored
    with setSuccessor() at its list position, and gets a const getter.

    Returns the tuple (params, initializers, getters) of lists of C++ source
    fragments.
    """
    numbered = list(enumerate(successors or []))

    # Parameters for the class constructor.
    params = [f"MBasicBlock* {block}" for _, block in numbered]

    # Initializer instructions for constructor body.
    initializers = [f"setSuccessor({slot}, {block});" for slot, block in numbered]

    # Getter definitions.
    getters = [
        f"MBasicBlock* {block}() const {{ return getSuccessor({slot}); }}"
        for slot, block in numbered
    ]

    return (params, initializers, getters)


def gen_lir_class(
    name,
    result_type,
    successors,
    operands,
    arguments,
    num_temps,
    num_temps64,
    call_instruction,
    mir_op,
    extra_name,
    defer_init,
):
    """Return the C++ class definition for a single LIR opcode.

    The class is named "L" + name. If successors is not None, the class
    derives from LControlInstructionHelper and must not have a result type;
    otherwise it derives from LInstructionHelper. operands, arguments, the
    temp counts and defer_init are forwarded to the matching gen_* helpers.
    A truthy call_instruction makes the constructor call setIsCall(). A
    truthy mir_op adds a mir() accessor (for M<name> when mir_op is True,
    otherwise for M<mir_op>). A truthy extra_name adds an extraName()
    declaration.

    The result is a single string with blank lines removed and the remaining
    lines joined by a backslash followed by a newline.

    Raises Exception if a control instruction specifies a result type.
    """
    cls = "L" + name

    (
        operand_count,
        operand_indices,
        operand_params,
        operand_inits,
        operand_getters,
        operand_setters,
    ) = gen_operands(operands, defer_init)

    member_decls, member_params, member_inits, member_getters = gen_arguments(
        arguments
    )

    temp_count, temp_params, temp_inits, temp_getters, temp_setters = gen_temps(
        num_temps, num_temps64, defer_init
    )

    successor_params, successor_inits, successor_getters = gen_successors(successors)

    # Pick the base class and its first template argument.
    if successors is None:
        base = "LInstructionHelper"
        first_template_arg = result_types[result_type] if result_type else "0"
    else:
        if result_type:
            raise Exception("Control instructions don't return a result")
        base = "LControlInstructionHelper"
        first_template_arg = len(successors)

    ctor_prologue = "this->setIsCall();" if call_instruction else ""

    if mir_op:
        mir_name = name if mir_op is True else mir_op
        mir_accessor = f"M{mir_name}* mir() const {{ return mir_->to{mir_name}(); }};"
    else:
        mir_accessor = ""

    extra_name_decl = "inline const char* extraName() const;" if extra_name else ""

    ctor_params = ", ".join(
        successor_params + operand_params + temp_params + member_params
    )
    # Every member initializer follows the base class initializer, so each
    # one is preceded by a comma.
    member_init_list = "".join(", " + init for init in member_inits)

    def indented(items, indent):
        # Lay out the items one per line, each starting with |indent|. An
        # empty list yields a whitespace-only line, which is dropped below.
        return indent + ("\n" + indent).join(items)

    pieces = [
        f"class {cls} : public {base}<{first_template_arg}, {operand_count}, {temp_count}> {{",
        indented(member_decls, "  "),
        " public:",
        f"  LIR_HEADER({name})",
        indented(operand_indices, "  "),
        f"  explicit {cls}({ctor_params}) : {base}(classOpcode){member_init_list} {{",
        "    " + ctor_prologue,
        indented(successor_inits, "    "),
        indented(operand_inits, "    "),
        indented(temp_inits, "    "),
        "  }",
        indented(successor_getters, "  "),
        indented(operand_getters, "  "),
        indented(operand_setters, "  "),
        indented(temp_getters, "  "),
        indented(temp_setters, "  "),
        indented(member_getters, "  "),
        "  " + mir_accessor,
        "  " + extra_name_decl,
        "};",
    ]
    code = "\n".join(pieces)

    # Remove blank lines and add backslashes at line endings.
    kept = [line for line in code.splitlines() if line.strip()]
    return "\\\n".join(kept)


def mir_type_to_lir_type(mir_type):
    """Translate a MIR type name into the corresponding LIR type name.

    "Value" becomes "BoxedValue" and "Int64" stays "Int64"; every other type
    is treated as "WordSized".
    """
    special_cases = (
        ("Value", "BoxedValue"),
        ("Int64", "Int64"),
    )
    for mir_name, lir_name in special_cases:
        if mir_type == mir_name:
            return lir_name

    return "WordSized"


def generate_lir_header(c_out, yaml_path, mir_yaml_path):
    """Generate the LIR opcode list and class definitions header.

    Every op in the YAML file at yaml_path is added to LIR_OPCODE_LIST; a
    class definition is generated for it unless 'gen_boilerplate' is false.
    Ops in the YAML file at mir_yaml_path contribute an opcode and a class
    definition only when they set 'generate_lir: true'. The resulting header
    is written to c_out.
    """
    data = load_yaml(yaml_path)

    # LIR_OPCODE_LIST opcode.
    ops = []

    # Generated LIR op class definitions.
    lir_op_classes = []

    for op in data:
        name = op["name"]

        gen_boilerplate = op.get("gen_boilerplate", True)
        assert isinstance(gen_boilerplate, bool)

        if gen_boilerplate:
            result_type = op.get("result_type", None)
            assert result_type is None or result_type in result_types

            successors = op.get("successors", None)
            assert successors is None or isinstance(successors, list)

            # A key that is present but empty parses as None; treat it like a
            # missing key.
            operands = op.get("operands") or {}
            assert isinstance(operands, dict)

            arguments = op.get("arguments") or {}
            assert isinstance(arguments, dict)

            num_temps = op.get("num_temps", 0)
            assert isinstance(num_temps, int)

            num_temps64 = op.get("num_temps64", 0)
            assert isinstance(num_temps64, int)

            call_instruction = op.get("call_instruction", None)
            assert isinstance(call_instruction, (type(None), bool))

            mir_op = op.get("mir_op", None)
            assert mir_op in (None, True) or isinstance(mir_op, str)

            extra_name = op.get("extra_name", False)
            assert isinstance(extra_name, bool)

            defer_init = op.get("defer_init", False)
            assert isinstance(defer_init, bool)

            lir_op_classes.append(
                gen_lir_class(
                    name,
                    result_type,
                    successors,
                    operands,
                    arguments,
                    num_temps,
                    num_temps64,
                    call_instruction,
                    mir_op,
                    extra_name,
                    defer_init,
                )
            )

        # The opcode is listed even when its class is not generated here.
        ops.append(f"_({name})")

    # Generate LIR instructions for MIR instructions with 'generate_lir': true
    mir_data = load_yaml(mir_yaml_path)

    for op in mir_data:
        name = op["name"]

        generate_lir = op.get("generate_lir", False)
        assert isinstance(generate_lir, bool)

        if generate_lir:
            result_type = op.get("result_type", None)
            assert isinstance(result_type, (type(None), str))

            if result_type:
                result_type = mir_type_to_lir_type(result_type)
                assert result_type in result_types

            successors = None

            # As for LIR ops above, an empty 'operands' key means no operands.
            operands_raw = op.get("operands") or {}
            assert isinstance(operands_raw, dict)

            operands = {
                operand: mir_type_to_lir_type(mir_type)
                for operand, mir_type in operands_raw.items()
            }

            arguments = {}

            num_temps = op.get("lir_temps", 0)
            assert isinstance(num_temps, int)

            num_temps64 = op.get("lir_temps64", 0)
            assert isinstance(num_temps64, int)

            call_instruction = op.get("possibly_calls", None)
            assert isinstance(call_instruction, (type(None), bool))

            # The generated class always exposes the MIR node of the same name.
            mir_op = True

            extra_name = False

            defer_init = False

            lir_op_classes.append(
                gen_lir_class(
                    name,
                    result_type,
                    successors,
                    operands,
                    arguments,
                    num_temps,
                    num_temps64,
                    call_instruction,
                    mir_op,
                    extra_name,
                    defer_init,
                )
            )

            ops.append(f"_({name})")

    contents = "#define LIR_OPCODE_LIST(_)\\\n"
    contents += "\\\n".join(ops)
    contents += "\n\n"

    contents += "#define LIR_OPCODE_CLASS_GENERATED \\\n"
    contents += "\\\n".join(lir_op_classes)
    contents += "\n\n"

    generate_header(c_out, "jit_LIROpsGenerated_h", contents)
