/*
 * ProRes RAW decoder
 *
 * Copyright (c) 2025 Lynne <dev@lynne.ee>
 *
 * This file is part of FFmpeg.
 *
 * FFmpeg is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * FFmpeg is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with FFmpeg; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 */

#pragma shader_stage(compute)
#extension GL_GOOGLE_include_directive : require

#include "common.glsl"

/* Per-tile record read from tile_data[] by main(). */
struct TileData {
   ivec2 pos;    /* Base pixel position of the tile in dst; pos.x is checked against the image width */
   uint offset;  /* Byte offset of the tile's data, added to the pkt_data address */
   uint size;    /* Total byte size of the tile: header plus the four component bitstreams */
};

/* Destination image; store_val() is the only writer in this file. */
layout (set = 0, binding = 0) uniform writeonly uimage2D dst;
/* One TileData entry per tile, indexed in main() by the flattened workgroup index. */
layout (set = 0, binding = 1, scalar) readonly buffer frame_data_buf {
    TileData tile_data[];
};

/*
 * pkt_data:  base address of the packet bytes; TileData.offset is added to it.
 * tile_size: only .x is read in this file, where main() clips it against the
 *            remaining image width.
 */
layout (push_constant, scalar) uniform pushConstants {
   u8buf pkt_data;
   ivec2 tile_size;
};

/*
 * Component handled by this invocation. main() uses it to index four-element
 * vectors, so it is expected to be in the range 0..3.
 */
#define COMP_ID (gl_LocalInvocationID.y)

/* Bit reader state used by get_value() and the read_*_vals() helpers;
 * set up in main() through init_get_bits(). */
GetBitContext gb;

/*
 * Codebook descriptors for DC codes. read_dc_vals() selects the entry from
 * the previously decoded DC code, with the index saturated at DC_CB_MAX.
 * Each descriptor is unpacked by get_value() as:
 *   bits 0..3 - rice_order, bits 4..7 - exp_order, bits 8 and up - switch_bits.
 */
#define DC_CB_MAX 12
const uint8_t dc_cb[DC_CB_MAX + 1] = {
    U8(16), U8(33), U8(50), U8(51), U8(51), U8(51),
    U8(68), U8(68), U8(68), U8(68), U8(68), U8(68), U8(118)
};

/*
 * Codebook descriptors for AC codes. read_ac_vals() selects the entry for the
 * next AC code from the AC code just decoded, with the index saturated at
 * AC_CB_MAX. Descriptors use the packing unpacked by get_value():
 *   bits 0..3 - rice_order, bits 4..7 - exp_order, bits 8 and up - switch_bits.
 */
#define AC_CB_MAX 94
const int16_t ac_cb[AC_CB_MAX + 1] = {
    I16(  0), I16(529), I16(273), I16(273), I16(546), I16(546),
    I16(546), I16(290), I16(290), I16(290), I16(563), I16(563),
    I16(563), I16(563), I16(563), I16(563), I16(563), I16(563),
    I16(307), I16(307), I16(580), I16(580), I16(580), I16(580),
    I16(580), I16(580), I16(580), I16(580), I16(580), I16(580),
    I16(580), I16(580), I16(580), I16(580), I16(580), I16(580),
    I16(580), I16(580), I16(580), I16(580), I16(580), I16(580),
    I16(853), I16(853), I16(853), I16(853), I16(853), I16(853),
    I16(853), I16(853), I16(853), I16(853), I16(853), I16(853),
    I16(853), I16(853), I16(853), I16(853), I16(853), I16(853),
    I16(853), I16(853), I16(853), I16(853), I16(853), I16(853),
    I16(853), I16(853), I16(853), I16(853), I16(853), I16(853),
    I16(853), I16(853), I16(853), I16(853), I16(853), I16(853),
    I16(853), I16(853), I16(853), I16(853), I16(853), I16(853),
    I16(853), I16(853), I16(853), I16(853), I16(853), I16(853),
    I16(853), I16(853), I16(853), I16(853), I16(358)
};

/*
 * Codebook descriptors for the `rn` code read in read_ac_vals(), where a
 * decoded rn advances the coefficient position by rn + 1 without storing
 * anything. The entry for the next rn code is selected from the rn value just
 * decoded, with the index saturated at RN_CB_MAX. Descriptors use the packing
 * unpacked by get_value():
 *   bits 0..3 - rice_order, bits 4..7 - exp_order, bits 8 and up - switch_bits.
 */
#define RN_CB_MAX 27
const int16_t rn_cb[RN_CB_MAX + 1] = {
    I16(512), I16(256), I16(  0), I16(  0), I16(529), I16(529), I16(273),
    I16(273), I16( 17), I16( 17), I16( 33), I16( 33), I16(546), I16( 34),
    I16( 34), I16( 34), I16( 34), I16( 34), I16( 34), I16( 34), I16( 34),
    I16( 34), I16( 34), I16( 34), I16( 34), I16( 50), I16( 50), I16( 68),
};

/*
 * Codebook descriptors for the `ln` code read in read_ac_vals(), where a
 * decoded ln is the number of consecutive AC codes read next. The entry is
 * selected from the AC code decoded right after a skip, with the index
 * saturated at LN_CB_MAX. Descriptors use the packing unpacked by get_value():
 *   bits 0..3 - rice_order, bits 4..7 - exp_order, bits 8 and up - switch_bits.
 */
#define LN_CB_MAX 14
const int16_t ln_cb[LN_CB_MAX + 1] = {
    I16( 256), I16( 273), I16( 546), I16( 546), I16( 290), I16( 290), I16( 1075),
    I16(1075), I16( 563), I16( 563), I16( 563), I16( 563), I16( 563), I16( 563),
    I16( 51)
};

/*
 * Read one variable-length code from the global bit reader `gb`.
 *
 * `codebook` packs three parameters:
 *   bits 0..3      - rice_order
 *   bits 4..7      - exp_order
 *   bits 8 and up  - switch_bits
 *
 * With q being the number of leading zero bits in the next 32 bits:
 *   - q <= switch_bits: the code is q zero bits, a one bit and rice_order
 *     suffix bits; the value is (q << rice_order) + suffix. The q == 0 case
 *     is handled first as a shortcut.
 *   - otherwise: the code is `bits` = exp_order + 2*q - switch_bits bits long
 *     and the value is derived from its top `bits` bits.
 *
 * Returns 0 without consuming any bits when the next 32 bits are all zero,
 * or when the code would be longer than the 32 bits that were peeked.
 * The result is truncated to int16_t.
 */
int16_t get_value(int16_t codebook)
{
    const int16_t switch_bits = codebook >> 8;
    const int16_t rice_order  = codebook & I16(0xf);
    const int16_t exp_order   = (codebook >> 4) & I16(0xf);

    uint32_t b = show_bits(gb, 32);
    if (expectEXT(b == 0, false))
        return I16(0);

    /* Count of leading zero bits; b != 0 so findMSB() is in 0..31 */
    int16_t q = I16(31) - I16(findMSB(b));

    /* Shortcut for q == 0: a single marker bit followed by the suffix */
    if ((b & 0x80000000) != 0) {
        skip_bits(gb, 1 + rice_order);
        return I16((b & 0x7FFFFFFF) >> (31 - rice_order));
    }

    if (q <= switch_bits) {
        skip_bits(gb, q + rice_order + 1);
        /* Shift out the prefix and marker bit, then keep rice_order bits */
        return I16((q << rice_order) +
                   (((b << (q + 1)) >> 1) >> (31 - rice_order)));
    }

    int16_t bits = exp_order + (q << 1) - switch_bits;

    /* A code longer than the peeked window cannot be decoded, and
     * (32 - bits) would become a negative shift count, whose result is
     * undefined. Treat it as a corrupt stream and return a harmless value. */
    if (expectEXT(bits > I16(32), false))
        return I16(0);

    skip_bits(gb, bits);
    return I16((b >> (32 - bits)) +
               ((switch_bits + 1) << rice_order) -
               (1 << exp_order));
}

/* Maps a decoded DC code to (code + 1) / 2, rounded down for non-negative
 * codes. The argument is parenthesized so that expression arguments with
 * lower-precedence operators expand correctly. */
#define TODCCODEBOOK(x) (((x) + 1) >> 1)

/*
 * Write one decoded value into dst.
 *
 * offs - base pixel position for the current tile and component
 * blk  - block index within the tile; each block spans 8 columns
 * c    - value index within the block: low 3 bits pick the column,
 *        the remaining bits pick the row
 * v    - value to store; its low 16 bits are written to every channel
 *
 * Both coordinates are doubled before being added to offs, so successive
 * values of one component land on every second pixel in each direction.
 */
void store_val(ivec2 offs, int blk, int c, int16_t v)
{
    const int col = blk*8 + (c & 7);
    const int row = c >> 3;
    const ivec2 pos = offs + 2*ivec2(col, row);

    imageStore(dst, pos, ivec4(v & 0xFFFF));
}

/*
 * Decode the DC value of every block in the tile and store it at
 * position 0 of each block.
 *
 * offs      - base pixel position passed through to store_val()
 * nb_blocks - number of blocks in the tile
 *
 * The first DC is coded on its own; each following DC is coded as a
 * difference from the previous one. Decoding stops early when the bit
 * reader runs out of data.
 */
void read_dc_vals(ivec2 offs, int nb_blocks)
{
    int16_t dc, dc_add;
    int16_t prev_dc = I16(0), sign = I16(0);

    /* Special handling for first block: fixed codebook, and the low bit of
     * the code selects the sign of the value held in the remaining bits */
    dc = get_value(I16(700));
    prev_dc = (dc >> 1) ^ -(dc & I16(1));
    store_val(offs, 0, 0, prev_dc);

    for (int n = 1; n < nb_blocks; n++) {
        if (expectEXT(left_bits(gb) <= 0, false))
            break;

        /* Every 16 blocks (n = 1, 17, 33, ...) the codebook is reset to a
         * fixed one; otherwise it adapts to the previous code. */
        uint8_t dc_codebook;
        if ((n & 15) == 1) {
            dc_codebook = uint8_t(100);
        } else {
            /* get_value() truncates to int16_t, so a corrupt stream can make
             * dc negative; clamp both ends to stay inside the table. */
            const int cb_idx = clamp(int(TODCCODEBOOK(dc)), 0, DC_CB_MAX);
            dc_codebook = dc_cb[cb_idx];
        }

        dc = get_value(dc_codebook);

        /* The low bit of the code toggles the sign relative to the sign of
         * the previous difference; the magnitude is (dc + 1) >> 1. */
        sign = sign ^ (dc & int16_t(1));
        dc_add = ((-sign) ^ I16(TODCCODEBOOK(dc))) + sign;
        sign = I16(dc_add < 0);
        prev_dc += dc_add;

        store_val(offs, n, 0, prev_dc);
    }
}

/*
 * Decode the AC values of all blocks in the tile.
 *
 * offs      - base pixel position passed through to store_val()
 * nb_blocks - number of blocks in the tile
 *
 * Values are addressed by a running position n in [nb_blocks, nb_blocks*64):
 * the low log2(nb_blocks) bits of n select the block and the remaining bits
 * select the position inside the block, so consecutive n visit the same
 * position of successive blocks.
 * NOTE(review): this mapping is exact only when nb_blocks is a power of two,
 * and nb_blocks must be > 0 for findMSB() to be meaningful - confirm that
 * callers guarantee both.
 *
 * The stream alternates between:
 *   ln - a count of AC codes that follow back to back,
 *   rn - a skip of rn + 1 positions, followed by one more AC code.
 * Each AC code is followed by one sign bit. Codebooks adapt to the most
 * recently decoded values. Decoding stops when the position reaches the end
 * of the tile or the bit reader runs out of data.
 */
void read_ac_vals(ivec2 offs, int nb_blocks)
{
    const int nb_codes = nb_blocks << 6;
    const int log2_nb_blocks = findMSB(nb_blocks);
    const int block_mask = (1 << log2_nb_blocks) - 1;

    int16_t ac, rn, ln;
    int16_t ac_codebook = I16(49);
    int16_t rn_codebook = I16( 0);
    int16_t ln_codebook = I16(66);
    int16_t sign;
    int16_t val;

    for (int n = nb_blocks; n <= nb_codes;) {
        if (expectEXT(left_bits(gb) <= 0, false))
            break;

        /* Run of ln consecutive coded values */
        ln = get_value(ln_codebook);
        for (int i = 0; i < ln; i++) {
            if (expectEXT(left_bits(gb) <= 0, false))
                break;

            if (expectEXT(n >= nb_codes, false))
                break;

            ac = get_value(ac_codebook);
            /* get_value() truncates to int16_t, so corrupt data can yield a
             * negative code; clamp both ends to stay inside the table. */
            ac_codebook = ac_cb[clamp(int(ac), 0, AC_CB_MAX)];
            sign = -int16_t(get_bit(gb));

            /* Magnitude is ac + 1; sign is 0 or -1, so this negates on -1 */
            val = ((ac + I16(1)) ^ sign) - sign;
            store_val(offs, n & block_mask, n >> log2_nb_blocks, val);

            n++;
        }

        if (expectEXT(n >= nb_codes, false))
            break;

        /* Skip rn + 1 positions */
        rn = get_value(rn_codebook);

        /* A negative skip can only come from int16_t truncation of a corrupt
         * code; it would move n backwards and could produce negative store
         * coordinates, so stop decoding instead. */
        if (expectEXT(rn < I16(0), false))
            break;

        rn_codebook = rn_cb[clamp(int(rn), 0, RN_CB_MAX)];

        n += rn + 1;
        if (expectEXT(n >= nb_codes, false))
            break;

        if (expectEXT(left_bits(gb) <= 0, false))
            break;

        /* The single coded value that terminates the skip */
        ac = get_value(ac_codebook);
        sign = -int16_t(get_bit(gb));

        val = ((ac + I16(1)) ^ sign) - sign;
        store_val(offs, n & block_mask, n >> log2_nb_blocks, val);

        ac_codebook = ac_cb[clamp(int(ac), 0, AC_CB_MAX)];
        ln_codebook = ln_cb[clamp(int(ac), 0, LN_CB_MAX)];

        n++;
    }
}

/*
 * Entry point: each workgroup handles one tile and each invocation decodes
 * the component selected by COMP_ID.
 *
 * Tile data starts with a header: the first byte shifted right by 3 gives the
 * header length, and the two-byte fields at indices 1, 2 and 3 of the header
 * give the byte sizes of components 2, 1 and 3 (combined with pack16() on the
 * swapped byte pair - presumably a big-endian read; confirm against
 * common.glsl). Component 0 takes whatever remains of the tile. The component
 * bitstreams follow the header in the order 2, 1, 3, 0.
 */
void main(void)
{
    /* Row-major tile index over the dispatch grid */
    const uint tile_idx = gl_WorkGroupID.x + gl_NumWorkGroups.x*gl_WorkGroupID.y;
    TileData td = tile_data[tile_idx];

    /* Tiles that start past the right edge have nothing to decode */
    const int width = imageSize(dst).x;
    if (expectEXT(td.pos.x >= width, false))
        return;

    const uint64_t pkt_offset = uint64_t(pkt_data) + td.offset;
    u8vec2buf hdr_data = u8vec2buf(pkt_offset);
    const int header_len = hdr_data[0].v.x >> 3;

    /* Per-component bitstream sizes in bytes */
    const int len_c1 = int(pack16(hdr_data[2].v.yx));
    const int len_c2 = int(pack16(hdr_data[1].v.yx));
    const int len_c3 = int(pack16(hdr_data[3].v.yx));
    const int len_c0 = int(td.size) - len_c1 - len_c2 - len_c3 - header_len;

    /* Declared sizes exceed the tile: corrupt data, decode nothing */
    if (expectEXT(len_c0 < 0, false))
        return;

    const ivec4 size = ivec4(len_c0, len_c1, len_c2, len_c3);

    /* Start of each component's bitstream relative to the end of the header */
    const ivec4 comp_offset = ivec4(len_c2 + len_c1 + len_c3,
                                    len_c2,
                                    0,
                                    len_c2 + len_c1);

    /* Each component starts at its own pixel of a 2x2 cell */
    const ivec2 offs = td.pos + ivec2(COMP_ID & 1, COMP_ID >> 1);

    /* Half the clipped tile width, divided into blocks of 8 columns */
    const int nb_blocks = min(tile_size.x, width - td.pos.x) >> 4;

    init_get_bits(gb, u8buf(pkt_offset + header_len + comp_offset[COMP_ID]),
                  size[COMP_ID]);

    read_dc_vals(offs, nb_blocks);
    read_ac_vals(offs, nb_blocks);
}
