/*
 * Copyright (c) 2024 Lynne <dev@lynne.ee>
 *
 * This file is part of FFmpeg.
 *
 * FFmpeg is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * FFmpeg is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with FFmpeg; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 */

#ifndef VULKAN_COMMON_H
#define VULKAN_COMMON_H

#pragma use_vulkan_memory_model

layout (local_size_x_id = 253, local_size_y_id = 254, local_size_z_id = 255) in;

#ifdef DEBUG
#extension GL_EXT_debug_printf : require
#define printf debugPrintfEXT
#endif

#extension GL_EXT_shader_explicit_arithmetic_types : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_float32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_float64 : require
#extension GL_EXT_shader_8bit_storage : require
#extension GL_EXT_shader_16bit_storage : require

#extension GL_EXT_shader_image_load_formatted : require
#extension GL_EXT_nonuniform_qualifier : require
#extension GL_EXT_scalar_block_layout : require
#extension GL_EXT_buffer_reference : require
#extension GL_EXT_buffer_reference2 : require
#extension GL_KHR_memory_scope_semantics : require
#extension GL_EXT_null_initializer : require

#extension GL_EXT_expect_assume : enable
#extension GL_EXT_control_flow_attributes : enable

/*
 * Buffer-reference types for accessing device memory through raw 64-bit
 * addresses (GL_EXT_buffer_reference). Each block wraps one member `v`.
 * Indexing a reference (ptr[i].v) steps by the size of that member; e.g.
 * LOAD64 reads ptr[0] and ptr[1] of a u8vec4buf to fetch 8 bytes.
 * buffer_reference_align is the alignment the address is declared to have,
 * so an address cast to one of these types must honour it.
 */

/* Single byte, no alignment requirement. */
layout(buffer_reference, buffer_reference_align = 1) buffer u8buf {
    uint8_t v;
};

/* Two bytes, byte-aligned. */
layout(buffer_reference, buffer_reference_align = 1) buffer u8vec2buf {
    u8vec2 v;
};

/* Four bytes, byte-aligned. */
layout(buffer_reference, buffer_reference_align = 1) buffer u8vec4buf {
    u8vec4 v;
};

/* 16-bit word, 2-byte aligned. */
layout(buffer_reference, buffer_reference_align = 2) buffer u16buf {
    uint16_t v;
};

/* 32-bit word, 4-byte aligned. */
layout(buffer_reference, buffer_reference_align = 4) buffer u32buf {
    uint32_t v;
};

/* Two 32-bit words (8 bytes), only 4-byte alignment required. */
layout(buffer_reference, buffer_reference_align = 4) buffer u32vec2buf {
    u32vec2 v;
};

/* Four 32-bit words (16 bytes), only 4-byte alignment required. */
layout(buffer_reference, buffer_reference_align = 4) buffer u32vec4buf {
    u32vec4 v;
};

/* 64-bit word, 8-byte aligned. */
layout(buffer_reference, buffer_reference_align = 8) buffer u64buf {
    uint64_t v;
};

/* Shorthand constructors (conversions) for explicitly sized unsigned integers. */
#define U8(x)  uint8_t(x)
#define U16(x) uint16_t(x)
#define U32(x) uint32_t(x)
#define U64(x) uint64_t(x)

/* Shorthand constructors (conversions) for explicitly sized signed integers. */
#define I8(x)  int8_t(x)
#define I16(x) int16_t(x)
#define I32(x) int32_t(x)
#define I64(x) int64_t(x)

/*
 * Reinterpret buffer reference `b` advanced by `l` bytes as a buffer
 * reference of type `type`. The resulting address must satisfy `type`'s
 * buffer_reference_align.
 */
#define OFFBUF(type, b, l) \
    type(uint64_t(b) + uint64_t(l))

/*
 * Keep only the low `p` bits of `a`.
 * NOTE(review): the mask is built as `1 << p` on a 32-bit int, so `p` must
 * be below 32 (p >= 32 is an undefined shift), and the mask can never cover
 * more than 31 bits even when `a` is 64-bit.
 */
#define zero_extend(a, p) \
    ((a) & ((1 << (p)) - 1))

/*
 * Sign-extend the low `bits` bits of `val`.
 * bitfieldExtract() only replicates the top extracted bit for signed
 * operands; if `val` is unsigned the result is zero-extended instead.
 */
#define sign_extend(val, bits) \
    bitfieldExtract(val, 0, bits)

/* Wrap a difference into the signed range of `bits` bits (alias of sign_extend). */
#define fold(diff, bits) \
    sign_extend(diff, bits)

/* Median of three values. */
#define mid_pred(a, b, c) \
    max(min((a), (b)), min(max((a), (b)), (c)))

/*
 * Right shift rounding towards +infinity, i.e. ceil(a / 2^b).
 * Requires a signed `a`: with an unsigned operand the negations wrap and the
 * shift is logical, which gives a wrong result.
 */
#define ceil_rshift(a, b) \
    (-((-(a)) >> (b)))

/*
 * Round `src` up to the next multiple of `a`; a value that is already a
 * multiple is returned unchanged. `a` must be non-zero (src % 0 is
 * undefined). Arithmetic wraps modulo 2^32 for inputs near UINT_MAX.
 * TODO: optimize (e.g. a power-of-two fast path)
 */
uint align(uint src, uint a)
{
    uint rem = src % a;
    return (rem == 0) ? src : src + (a - rem);
}

/*
 * 64-bit variant of align(): round `src` up to the next multiple of `a`,
 * returning it unchanged when already a multiple. `a` must be non-zero.
 * TODO: optimize (e.g. a power-of-two fast path)
 */
uint64_t align64(uint64_t src, uint64_t a)
{
    uint64_t rem = src % a;
    return (rem == 0) ? src : src + (a - rem);
}

/* Byte-swap a 16-bit value: split into two bytes, swap, repack. */
#define reverse2(src) \
    (pack16(unpack8(uint16_t(src)).yx))

/* Byte-swap a 32-bit value: split into four bytes, reverse, repack. */
#define reverse4(src) \
    (pack32(unpack8(uint32_t(src)).wzyx))

/*
 * Byte-swap a 64-bit value and return it as two 32-bit words in memory
 * order (component x is stored first). The high word, byte-swapped, goes
 * first, followed by the byte-swapped low word.
 */
u32vec2 reverse8(uint64_t src)
{
    u32vec2 halves = unpack32(src); /* x = low 32 bits, y = high 32 bits */
    uint32_t first  = reverse4(halves.y);
    uint32_t second = reverse4(halves.x);
    return u32vec2(first, second);
}

/*
 * Bit-writer word size.
 * With PB_32, bits accumulate in a 32-bit word that is stored as one u32.
 * Otherwise a 64-bit word is accumulated and stored as a u32vec2.
 *   BUF_REVERSE(src)          - byte-swap the accumulator into big-endian
 *                               memory order, as the BUF_TYPE member type
 *   BYTE_EXTRACT(src, k)      - byte k of the accumulator, counted from the
 *                               least-significant end
 */
#ifdef PB_32
#define BIT_BUF_TYPE uint32_t
#define BUF_TYPE u32buf
#define BUF_REVERSE(src) reverse4(src)
#define BUF_BITS uint8_t(32)
#define BUF_BYTES uint8_t(4)
#define BYTE_EXTRACT(src, byte_off) \
    (uint8_t(bitfieldExtract((src), ((byte_off) << 3), 8)))
#else
#define BIT_BUF_TYPE uint64_t
#define BUF_TYPE u32vec2buf
#define BUF_REVERSE(src) reverse8(src)
#define BUF_BITS uint8_t(64)
#define BUF_BYTES uint8_t(8)
#define BYTE_EXTRACT(src, byte_off) \
    (uint8_t(((src) >> ((byte_off) << 3)) & 0xFF))
#endif

/* State of the big-endian (MSB-first) bit writer. */
struct PutBitContext {
    uint64_t buf_start; /* address of the first output byte */
    uint64_t buf;       /* address where the next word/byte will be stored */

    /* Pending bits not yet stored; the low (BUF_BITS - bit_left) bits are valid. */
    BIT_BUF_TYPE bit_buf;
    uint8_t bit_left;   /* free bit positions remaining in bit_buf */
};

/*
 * Append the low `n` bits of `value` to the bitstream, MSB first.
 *
 * Bits accumulate in pb.bit_buf. When a write would fill it, the word is
 * completed with the top bits of `value` and stored in big-endian byte
 * order, and the remaining bits of `value` start the next word.
 *
 * `value` must not have bits set at or above bit `n`, because it is ORed
 * straight into bit_buf. With PB_32, n == 32 while bit_buf is empty
 * (bit_left == 32) would shift a 32-bit word by 32, which is undefined.
 * NOTE(review): presumably n <= 31 is expected, as in FFmpeg's C
 * put_bits(); confirm against callers.
 */
void put_bits(inout PutBitContext pb, const uint32_t n, uint32_t value)
{
    if (n < pb.bit_left) {
        /* Still room after this write: just append below the pending bits */
        pb.bit_buf = (pb.bit_buf << n) | value;
        pb.bit_left -= uint8_t(n);
    } else {
        /* Fill the remaining bit_left positions with the top bits of value */
        pb.bit_buf <<= pb.bit_left;
        pb.bit_buf |= (value >> (n - pb.bit_left));

#ifdef PB_UNALIGNED
        /* Byte-by-byte store, MSB first: no alignment needed on pb.buf */
        u8buf bs = u8buf(pb.buf);
        [[unroll]]
        for (uint8_t i = uint8_t(0); i < BUF_BYTES; i++)
            bs[i].v = BYTE_EXTRACT(pb.bit_buf, BUF_BYTES - uint8_t(1) - i);
#else
#ifdef DEBUG
        if ((pb.buf % BUF_BYTES) != 0)
            debugPrintfEXT("put_bits buffer is not aligned!");
#endif

        /* Single word store; pb.buf must meet BUF_TYPE's declared alignment */
        BUF_TYPE bs = BUF_TYPE(pb.buf);
        bs.v = BUF_REVERSE(pb.bit_buf);
#endif
        pb.buf = uint64_t(bs) + BUF_BYTES;

        /* (n - old bit_left) low bits of value are still pending, so the
         * new free count is BUF_BITS - (n - old bit_left). The already
         * written high bits of value stay above the valid range and are
         * shifted out before the next store. */
        pb.bit_left += BUF_BITS - uint8_t(n);
        pb.bit_buf = value;
    }
}

/*
 * Store any pending bits, padded with zero bits up to a whole byte, using
 * byte stores so pb.buf need not be aligned. Then reset the bit buffer to
 * empty and return the total number of bytes written since init_put_bits().
 */
uint32_t flush_put_bits(inout PutBitContext pb)
{
    if (pb.bit_left < BUF_BITS) {
        /* Push the pending bits up against the MSB; the vacated low bits are zero */
        pb.bit_buf <<= pb.bit_left;

        /* Bytes touched by the pending bits, rounded up */
        uint pending = ((BUF_BITS - pb.bit_left - 1) >> 3) + 1;

        u8buf dst = u8buf(pb.buf);
        for (int idx = 0; idx < pending; idx++)
            dst[idx].v = BYTE_EXTRACT(pb.bit_buf, BUF_BYTES - uint8_t(1) - idx);
        pb.buf = uint64_t(dst) + pending;
    }

    pb.bit_buf = 0x0;
    pb.bit_left = BUF_BITS;

    return uint32_t(pb.buf - pb.buf_start);
}

/*
 * Start writing at `data` with an empty bit buffer.
 * `len` is accepted for API symmetry but is not used: no bounds checking is
 * performed by the writer.
 */
void init_put_bits(out PutBitContext pb, u8buf data, uint64_t len)
{
    uint64_t start = uint64_t(data);

    pb.bit_left = BUF_BITS;
    pb.bit_buf = 0;

    pb.buf = start;
    pb.buf_start = start;
}

/* Total number of bits written so far: stored bytes plus bits still pending in bit_buf. */
uint64_t put_bits_count(in PutBitContext pb)
{
    uint64_t stored_bits  = (pb.buf - pb.buf_start) * 8;
    uint64_t pending_bits = uint64_t(BUF_BITS - pb.bit_left);
    return stored_bits + pending_bits;
}

/*
 * Number of bytes written so far: stored bytes plus whole bytes pending in
 * bit_buf. A trailing partial byte is not counted (flush_put_bits() would
 * still emit it).
 */
uint32_t put_bytes_count(in PutBitContext pb)
{
    uint64_t stored_bytes  = pb.buf - pb.buf_start;
    uint64_t pending_bytes = uint64_t((BUF_BITS - pb.bit_left) >> 3);
    return uint32_t(stored_bytes + pending_bytes);
}

/* State of the big-endian (MSB-first) bit reader. */
struct GetBitContext {
    uint64_t buf_start; /* address of the first byte of the stream */
    uint64_t buf;       /* address of the next byte not yet loaded into `bits` */
    uint64_t buf_end;   /* buf_start + len; only read by left_bits() in this file */

    /* Bit cache: the next unread bit is bit 63, and positions below the
     * bits_valid valid bits are kept zero so RELOAD32 can OR new data in. */
    uint64_t bits;
    int bits_valid;     /* number of valid bits at the top of `bits` */
#ifdef GET_BITS_SMEM
    int cur_smem_pos;   /* index of the next 32-bit word in this invocation's gb_storage slice */
#endif
};

/*
 * Refill macros for GetBitContext. They expect a variable named `gb` in
 * scope and expand to a braced statement block.
 */
#ifndef GET_BITS_SMEM
/* LOAD64: replace the cache with the next 8 bytes at gb.buf (big-endian);
 * bits_valid becomes 64. */
#define LOAD64()                                       \
    {                                                  \
        u8vec4buf ptr = u8vec4buf(gb.buf);             \
        uint32_t rf1 = pack32((ptr[0].v).wzyx);        \
        uint32_t rf2 = pack32((ptr[1].v).wzyx);        \
        gb.buf += 8;                                   \
        gb.bits = uint64_t(rf1) << 32 | uint64_t(rf2); \
        gb.bits_valid = 64;                            \
    }

/* RELOAD32: append the next 4 bytes (big-endian) directly below the valid
 * bits. Requires bits_valid <= 32 so the shift count stays non-negative. */
#define RELOAD32()                                                \
    {                                                             \
        u8vec4buf ptr = u8vec4buf(gb.buf);                        \
        uint32_t rf = pack32((ptr[0].v).wzyx);                    \
        gb.buf += 4;                                              \
        gb.bits = uint64_t(rf) << (32 - gb.bits_valid) | gb.bits; \
        gb.bits_valid += 32;                                      \
    }
#else /* GET_BITS_SMEM */
/* Per-invocation staging area in shared memory: GET_BITS_SMEM u32vec4s
 * (16 bytes each) per invocation, indexed by gl_LocalInvocationIndex. */
shared u32vec4 gb_storage[gl_WorkGroupSize.x*gl_WorkGroupSize.y*gl_WorkGroupSize.z*GET_BITS_SMEM];

/* FILL_SMEM: copy GET_BITS_SMEM u32vec4s starting at gb.buf into this
 * invocation's slice and restart at word 0. Does not advance gb.buf.
 * gb.buf must be 4-byte aligned (u32vec4buf's declared alignment).
 * NOTE(review): always reads GET_BITS_SMEM * 16 bytes, regardless of
 * gb.buf_end. */
#define FILL_SMEM()                                                             \
    {                                                                           \
        u32vec4buf ptr = u32vec4buf(gb.buf);                                    \
        [[unroll]]                                                              \
        for (uint i = 0; i < GET_BITS_SMEM; ++i)                                \
            gb_storage[gl_LocalInvocationIndex * GET_BITS_SMEM + i] = ptr[i].v; \
        gb.cur_smem_pos = 0;                                                    \
    }

/* LOAD64 (SMEM variant): reset the cache, read single bytes until the
 * address is 4-byte aligned, then fill shared memory.
 * Unlike the non-SMEM LOAD64, this does not load 64 bits. It loads 0-3
 * bytes, so bits_valid can remain 0 when the stream start is already
 * aligned.
 * NOTE(review): the byte count is derived from gb.buf_start, not gb.buf, so
 * this is only correct while gb.buf == gb.buf_start (as in init_get_bits). */
#define LOAD64()                                                    \
    {                                                               \
        gb.bits = 0;                                                \
        gb.bits_valid = 0;                                          \
        u8buf ptr = u8buf(gb.buf);                                  \
        for (uint i = 0; i < ((4 - uint(gb.buf_start)) & 3); ++i) { \
            gb.bits |= uint64_t(ptr[i].v) << (56 - i * 8);          \
            gb.bits_valid += 8;                                     \
            gb.buf += 1;                                            \
        }                                                           \
        FILL_SMEM();                                                \
    }

/* RELOAD32 (SMEM variant): take the next 32-bit word from shared memory,
 * refilling from gb.buf once all 4 * GET_BITS_SMEM words are consumed.
 * The word is byte-swapped to big-endian and appended below the valid bits.
 * Requires bits_valid <= 32. */
#define RELOAD32()                                                                                  \
    {                                                                                               \
        if (gb.cur_smem_pos >= 4*GET_BITS_SMEM)                                                     \
            FILL_SMEM();                                                                            \
        u32vec4 vec = gb_storage[gl_LocalInvocationIndex * GET_BITS_SMEM + (gb.cur_smem_pos >> 2)]; \
        uint v = vec[gb.cur_smem_pos & 3];                                                          \
        gb.buf += 4;                                                                                \
        gb.bits = uint64_t(reverse4(v)) << (32 - gb.bits_valid) | gb.bits;                          \
        gb.bits_valid += 32;                                                                        \
        gb.cur_smem_pos += 1;                                                                       \
    }
#endif /* GET_BITS_SMEM */

/*
 * Set up `gb` to read the `len`-byte stream starting at `data` and preload
 * the bit cache with LOAD64.
 */
void init_get_bits(inout GetBitContext gb, u8buf data, int len)
{
    uint64_t start = uint64_t(data);

    gb.buf_start = start;
    gb.buf = start;
    gb.buf_end = start + len;

    /* Preload the bit cache */
    LOAD64()
}

/*
 * Read one bit from the stream and advance the reader.
 *
 * An empty cache is refilled with RELOAD32, like get_bits/show_bits/
 * skip_bits. It always appends 32 bits in both build modes. The GET_BITS_SMEM
 * variant of LOAD64 only consumes 0-3 alignment bytes (and derives them from
 * buf_start, not buf), so using it here could leave bits_valid at 0. A
 * garbage bit would then be returned and bits_valid would go negative.
 */
bool get_bit(inout GetBitContext gb)
{
    if (gb.bits_valid == 0)
        RELOAD32()

    bool val = bool(gb.bits >> (64 - 1));
    gb.bits <<= 1;
    gb.bits_valid--;
    return val;
}

/*
 * Read `n` bits (MSB first) and advance the reader; returns 0 for n == 0.
 * A single RELOAD32 is performed if needed, so n is presumably at most 32.
 */
uint get_bits(inout GetBitContext gb, int n)
{
    /* Zero-width read: avoid the undefined `bits >> 64` below */
    if (n == 0)
        return 0;

    if (gb.bits_valid < n)
        RELOAD32()

    uint result = uint(gb.bits >> (64 - n));
    gb.bits_valid -= n;
    gb.bits <<= n;
    return result;
}

/*
 * Return the next `n` bits (MSB first) without consuming them. The cache
 * may still be refilled, hence `inout`. Returns 0 for n == 0, like
 * get_bits(). Without that guard the shift below would be by 64, which is
 * undefined. A single RELOAD32 is performed if needed, so n is presumably
 * at most 32.
 */
uint show_bits(inout GetBitContext gb, int n)
{
    if (n == 0)
        return 0;

    if (n > gb.bits_valid)
        RELOAD32()

    return uint(gb.bits >> (64 - n));
}

/*
 * Discard the next `n` bits. Only one RELOAD32 is performed, so n is
 * presumably at most 32.
 */
void skip_bits(inout GetBitContext gb, int n)
{
    if (gb.bits_valid < n)
        RELOAD32()

    gb.bits_valid -= n;
    gb.bits <<= n;
}

/* Number of bits consumed since init_get_bits(): bytes loaded minus bits still cached. */
int tell_bits(in GetBitContext gb)
{
    int loaded_bits = int(gb.buf - gb.buf_start) * 8;
    return loaded_bits - gb.bits_valid;
}

/*
 * Number of bits still unread: cached bits plus bytes not yet loaded
 * before buf_end. Can be negative if a refill read past buf_end.
 */
int left_bits(in GetBitContext gb)
{
    int unloaded_bits = int(gb.buf_end - gb.buf) * 8;
    return gb.bits_valid + unloaded_bits;
}

#endif /* VULKAN_COMMON_H */
