#!/usr/bin/env python3
# GStreamer
# Copyright (C) 2025 Seungha Yang <seungha@centricular.com>
#
# This library is free software; you can redistribute it and/or
# modify it under the terms of the GNU Library General Public
# License as published by the Free Software Foundation; either
# version 2 of the License, or (at your option) any later version.
#
# This library is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
# Library General Public License for more details.
#
# You should have received a copy of the GNU Library General Public
# License along with this library; if not, write to the
# Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
# Boston, MA 02110-1301, USA.

import sys
import os
import argparse

# Preamble written at the very top of the aggregated output header (--output).
start_header = """/*
 * This file is autogenerated by collect_hsaco_headers.py
 */
#pragma once

"""

# Written after the per-kernel #include lines. Defines a helper macro that
# pairs each kernel's name (as a string, via G_STRINGIFY) with its g_<name>
# byte array, then opens the declaration of the lookup map; main() appends
# the map's variable name and " = {" right after this text.
# NOTE(review): the generated code uses G_STRINGIFY, std::unordered_map and
# std::string but does not #include their headers itself -- presumably the
# consuming translation unit includes them first; confirm.
start_map = """
#define MAKE_BYTECODE(name) { G_STRINGIFY (name), g_##name }
static std::unordered_map<std::string, const unsigned char *>
"""

# Closes the map initializer and removes the helper macro so it does not leak
# into code that includes the generated header.
end_map = """};
#undef MAKE_BYTECODE
"""

def convert_hsaco_to_header(hsaco_file, header_file, *, bytes_per_line=12):
    """Write a C header that embeds the raw bytes of an HSACO file.

    The header declares ``static const unsigned char g_<stem>[]``, where
    ``<stem>`` is the base name of ``hsaco_file`` without its extension,
    initialized with the file's bytes as ``0x..`` literals.

    Args:
        hsaco_file: Path of the compiled HIP code object (.hsaco) to embed.
        header_file: Path of the header to create or overwrite.
        bytes_per_line: Number of byte literals emitted per source line.

    Raises:
        ValueError: If ``bytes_per_line`` is smaller than 1.
        OSError: If the input cannot be read or the output cannot be written.
    """
    if bytes_per_line < 1:
        raise ValueError("bytes_per_line must be >= 1, got {}".format(bytes_per_line))

    with open(hsaco_file, 'rb') as f:
        hsaco_content = f.read()

    # The array name must match what main() emits as MAKE_BYTECODE (name),
    # which expands to g_##name, so derive it from the file stem the same way.
    var_name = os.path.splitext(os.path.basename(hsaco_file))[0]

    header_lines = [
        "// Generated by collect_hsaco_headers.py",
        "#pragma once",
        "static const unsigned char g_{}[] = {{".format(var_name),
    ]

    total = len(hsaco_content)
    for i in range(0, total, bytes_per_line):
        chunk = hsaco_content[i:i + bytes_per_line]
        line = "  " + ", ".join("0x{:02x}".format(b) for b in chunk)
        # Every line except the last ends with a separating comma.
        if i + bytes_per_line < total:
            line += ","
        header_lines.append(line)

    header_lines.append("};")
    header_lines.append("")  # makes the joined text end with a newline

    # newline='\n' keeps LF line endings on every platform, matching the
    # aggregate header written by main(); plain text mode would emit CRLF
    # on Windows.
    with open(header_file, "w", newline='\n', encoding='utf8') as f:
        f.write("\n".join(header_lines))

def main(args):
    """Generate per-kernel headers and one aggregate header for HSACO files.

    Every file in ``--input`` whose name starts with ``--prefix`` and ends with
    ``.hsaco`` is converted to a sibling ``.h`` file by
    ``convert_hsaco_to_header``. The ``--output`` header #includes all of them
    and defines a ``std::unordered_map`` named ``--name`` that maps each
    kernel's file stem to its embedded ``g_<stem>`` byte array.

    Args:
        args: Command-line arguments, excluding the program name.
    """
    parser = argparse.ArgumentParser(description='Read HIP HSACO from directory and make single header')
    # These options were always needed; declaring them required gives a clear
    # argparse error instead of a TypeError deep in the script.
    parser.add_argument("--input", required=True, help="the precompiled HIP HSACO directory")
    parser.add_argument("--output", required=True, help="output header file location")
    parser.add_argument("--prefix", default="", help="HIP HSACO header filename prefix")
    parser.add_argument("--name", required=True, help="Hash map variable name")

    args = parser.parse_args(args)

    # os.listdir() order is arbitrary; sorting keeps the generated output
    # byte-identical across runs and filesystems (reproducible builds, no
    # spurious rebuilds).
    hsaco_files = [
        os.path.join(args.input, entry)
        for entry in sorted(os.listdir(args.input))
        if entry.startswith(args.prefix) and entry.endswith(".hsaco")
    ]

    with open(args.output, 'w', newline='\n', encoding='utf8') as f:
        f.write(start_header)
        for hsaco_file in hsaco_files:
            header_file = os.path.splitext(hsaco_file)[0] + '.h'
            convert_hsaco_to_header(hsaco_file, header_file)
            f.write('#include "{}"\n'.format(os.path.basename(header_file)))
        f.write(start_map)
        f.write("{} = {{\n".format(args.name))
        for hsaco_file in hsaco_files:
            # Entry name must be the same file stem used for g_<stem> by
            # convert_hsaco_to_header.
            kernel_name = os.path.splitext(os.path.basename(hsaco_file))[0]
            f.write("  MAKE_BYTECODE ({}),\n".format(kernel_name))
        f.write(end_map)

if __name__ == "__main__":
    # Forward the CLI arguments (without the program name) and pass main()'s
    # return value on to sys.exit() as the process exit status.
    exit_status = main(sys.argv[1:])
    sys.exit(exit_status)
