/*
 * FFv1 codec
 *
 * Copyright (c) 2024 Lynne <dev@lynne.ee>
 *
 * This file is part of FFmpeg.
 *
 * FFmpeg is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation; either
 * version 2.1 of the License, or (at your option) any later version.
 *
 * FFmpeg is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public
 * License along with FFmpeg; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
 */

#pragma shader_stage(compute)
#extension GL_GOOGLE_include_directive : require

#define ENCODE
#include "common.glsl"
#include "ffv1_common.glsl"

/* 256-entry CRC lookup table. finalize_slice() indexes it with one byte per
 * step while computing the checksum appended to each slice. */
layout (set = 0, binding = 2, scalar) uniform crc_ieee_buf {
    uint32_t crc_ieee[256];
};

/* One entry per slice: finalize_slice() stores the final byte count of the
 * slice (payload plus trailer) at index slice_idx. Never read by this shader. */
layout (set = 1, binding = 1, scalar) writeonly buffer slice_results_buf {
    uint32_t slice_results[];
};
/* Input images. Indexed per plane by the encoding functions below; in RGB
 * builds load_components() reads src[0] and, for planar RGB, the other
 * planes as well. */
layout (set = 1, binding = 3) uniform uimage2D src[];

#ifndef GOLOMB

/* Range coder state bytes for all slices. encode_line() has every invocation
 * copy one byte of the selected context into rc_state before a symbol is
 * coded, and copy it back afterwards. */
layout (set = 1, binding = 2, scalar) buffer slice_state_buf {
    uint8_t slice_rc_state[];
};

/* Hands the bit `val` to put_rac() together with the state entry rc_state[idx]. */
#define WRITE(idx, val) put_rac(rc_state[idx], val)

/**
 * Codes the signed integer v as a sequence of binary decisions.
 *
 * State slots of rc_state used by this function:
 *    0       - "value is zero" flag
 *    1 .. 10 - unary-coded position of the highest set bit of |v|
 *   11 .. 21 - sign
 *   22 .. 31 - bits of |v| below the highest set bit
 */
void put_symbol(int v)
{
    bool zero = (v == 0);
    WRITE(0, zero);
    if (zero)
        return;

    int mag = abs(v);
    int top = findMSB(mag);

    /* Position of the leading one: `top` true bits, then a false terminator */
    int k = 0;
    while (k < top) {
        WRITE(1 + min(k, 9), true);
        k++;
    }
    WRITE(1 + min(top, 9), false);

    /* Remaining magnitude bits, most significant first */
    int b = top - 1;
    while (b >= 0) {
        WRITE(22 + min(b, 9), ((mag >> b) & 1) != 0);
        b--;
    }

    /* Sign last */
    WRITE(11 + min(top, 10), v < 0);
}

/**
 * Writes one line of raw (unpredicted) samples through put_rac_equi().
 *
 * Only invocation 0 of the workgroup does any work; the others return
 * immediately.
 *
 * sc    slice context, provides the slice width
 * img   image to read samples from
 * sp    base position used for addressing; shifted by chroma_shift for
 *       planes 1 and 2 in non-RGB builds
 * y     line number
 * p     plane index
 * comp  component of the loaded texel to encode
 */
void encode_line_pcm(in SliceContext sc, readonly uimage2D img,
                     ivec2 sp, int y, uint p, uint comp)
{
    if (gl_LocalInvocationID.x != 0)
        return;

    int w = sc.slice_dim.x;

#ifndef RGB
    /* Planes 1 and 2 use the chroma-subsampled width and position */
    if (p == 1 || p == 2) {
        w = ceil_rshift(w, chroma_shift.x);
        sp >>= chroma_shift;
    }
#endif

    /* Mask of the first bit written for every sample */
    uint first_mask = (rct_offset >> 1);

    for (int x = 0; x < w; x++) {
        uint v = imageLoad(img, sp + LADDR(ivec2(x, y)))[comp];

        /* Walk the mask down to bit 0, one coded bit per step */
        uint mask = first_mask;
        while (mask > 0) {
            put_rac_equi(bool(v & mask));
            mask >>= 1;
        }
    }
}

/**
 * Range coder variant: encodes one line of one plane.
 *
 * Executed by the whole workgroup. Every invocation moves one byte of the
 * selected context between slice_rc_state and rc_state, while only
 * invocation 0 codes the symbol. The barriers order those two steps.
 *
 * sc               slice context, provides the slice width
 * img              image to read samples from
 * state_off        offset of this plane's contexts within slice_rc_state
 * sp               base position used for addressing; shifted by
 *                  chroma_shift for planes 1 and 2 in non-RGB builds
 * y                line number
 * p                plane index
 * comp             component of the loaded texel to encode
 * quant_table_idx  forwarded to get_pred(), also selects the extend_lookup entry
 * run_index        not used by this variant; kept so both encode_line()
 *                  variants are called the same way
 */
void encode_line(in SliceContext sc, readonly uimage2D img, uint state_off,
                 ivec2 sp, int y, uint p, uint comp,
                 uint8_t quant_table_idx, in int run_index)
{
    int w = sc.slice_dim.x;

#ifndef RGB
    /* Planes 1 and 2 use the chroma-subsampled width and position */
    if (p > 0 && p < 3) {
        w = ceil_rshift(w, chroma_shift.x);
        sp >>= chroma_shift;
    }
#endif

    linecache_load(img, sp, y, comp);

    for (int x = 0; x < w; x++) {
        /* d[0] selects the context below; d[1] is subtracted from the
         * current sample (presumably the predicted value - confirm against
         * get_pred()) */
        ivec2 d = get_pred(img, sp, ivec2(x, y), comp, w,
                           quant_table_idx, extend_lookup[quant_table_idx]);
        TYPE cur = TYPE(imageLoad(img, sp + LADDR(ivec2(x, y)))[comp]);
        d[1] = int(cur) - d[1];

        /* A negative context is mirrored: context and difference are both negated */
        if (d[0] < 0)
            d = -d;

        d[1] = fold(d[1], bits);

        /* Byte of context d[0] handled by this invocation */
        uint rc_off = state_off + CONTEXT_SIZE*d[0] + gl_LocalInvocationID.x;

        /* Load: must be complete for all invocations before coding starts */
        rc_state[gl_LocalInvocationID.x] = slice_rc_state[rc_off];
        barrier();

        if (gl_LocalInvocationID.x == 0) {
            put_symbol(d[1]);
            linecache_next(cur);
        }

        /* Store: must not start before invocation 0 finished coding */
        barrier();
        slice_rc_state[rc_off] = rc_state[gl_LocalInvocationID.x];
    }
}

#else /* GOLOMB */

/* VLC states used by the Golomb coding path. encode_line() passes the entry
 * at state_off + context to get_vlc_symbol(). */
layout (set = 1, binding = 2, scalar) buffer slice_state_buf {
    VlcState slice_vlc_state[];
};

/* Value returned by rac_terminate() in init_golomb(); added to the payload
 * length in finalize_slice() */
uint hdr_len = 0;

/* Bit writer used for everything written after init_golomb() */
PutBitContext pb;

/**
 * Terminates the range coder and points the bit writer `pb` at the bytes
 * that follow, with slice_size_max - hdr_len as its size limit.
 *
 * NOTE(review): finalize_slice() addresses the bitstream as
 * slice_data + rc.bs_start, whereas rc.bs_start is handed to OFFBUF as-is
 * here - confirm both refer to the same location.
 */
void init_golomb(void)
{
    hdr_len = rac_terminate();

    u8buf payload = OFFBUF(u8buf, rc.bs_start, hdr_len);
    init_put_bits(pb, payload, slice_size_max - hdr_len);
}

/* While the pending run is at least 1 << log2_run[run_index] long, removes
 * that many samples from it, advances run_index and emits a single '1' bit. */
void put_full_runs(inout int run_count, inout int run_index)
{
    while (run_count >= (1 << log2_run[run_index])) {
        run_count -= 1 << log2_run[run_index];
        run_index++;
        put_bits(pb, 1, 1);
    }
}

/**
 * Golomb variant: encodes one line of one plane into the bit writer `pb`.
 *
 * sc               slice context, provides the slice width
 * img              image to read samples from
 * state_off        index of this plane's first entry in slice_vlc_state
 * sp               base position used for addressing; shifted by
 *                  chroma_shift for planes 1 and 2 in non-RGB builds
 * y                line number
 * p                plane index
 * comp             component of the loaded texel to encode
 * quant_table_idx  forwarded to get_pred(), also selects the extend_lookup entry
 * run_index        run-length state, updated in place and carried over to
 *                  the next call by the caller
 */
void encode_line(in SliceContext sc, readonly uimage2D img, uint state_off,
                 ivec2 sp, int y, uint p, uint comp,
                 uint8_t quant_table_idx, inout int run_index)
{
    int w = sc.slice_dim.x;

#ifndef RGB
    /* Planes 1 and 2 use the chroma-subsampled width and position */
    if (p == 1 || p == 2) {
        w = ceil_rshift(w, chroma_shift.x);
        sp >>= chroma_shift;
    }
#endif

    linecache_load(img, sp, y, comp);

    int  pending = 0;     /* zero differences seen in the current run */
    bool in_run  = false;

    for (int x = 0; x < w; x++) {
        ivec2 pos = ivec2(x, y);

        ivec2 d = get_pred(img, sp, pos, comp, w,
                           quant_table_idx, extend_lookup[quant_table_idx]);
        TYPE cur = TYPE(imageLoad(img, sp + LADDR(pos))[comp]);
        d[1] = int(cur) - d[1];
        linecache_next(cur);

        /* A negative context is mirrored: context and difference are both negated */
        if (d[0] < 0)
            d = -d;

        int ctx  = d[0];
        int diff = fold(d[1], bits);

        /* Context 0 switches to run mode */
        in_run = in_run || (ctx == 0);

        if (in_run) {
            if (diff == 0) {
                pending++;
            } else {
                /* The run ends here: emit the completed segments, then the
                 * remaining count, and leave run mode */
                put_full_runs(pending, run_index);

                put_bits(pb, 1 + log2_run[run_index], pending);
                if (run_index != 0)
                    run_index--;
                pending = 0;
                in_run  = false;

                /* Zero cannot occur at this point, so positive values move down by one */
                if (diff > 0)
                    diff--;
            }
        }

        if (!in_run) {
            Symbol sym = get_vlc_symbol(slice_vlc_state[state_off + ctx],
                                        diff, bits);
            put_bits(pb, sym.bits, sym.val);
        }
    }

    /* The line ended inside a run */
    if (in_run) {
        put_full_runs(pending, run_index);

        if (pending > 0)
            put_bits(pb, 1, 1);
    }
}
#endif

#ifdef RGB
/* Component of the temporary image that encode_slice() encodes as coded
 * plane 0, 1, 2 and 3 respectively */
const uvec4 rgb_plane_order = uvec4(1, 2, 0, 3);

/**
 * Loads all components of the source pixel at pos.
 *
 * The texel of src[0] is read first. For planar RGB, components 1 and up
 * (three in total, four with transparency) are replaced by the first channel
 * of their own plane. The result is reordered through fmt_lut.
 */
ivec4 load_components(ivec2 pos)
{
    ivec4 raw = ivec4(imageLoad(src[0], pos));

    if (planar_rgb) {
        int nb_planes = 3 + int(transparency);
        for (int pl = 1; pl < nb_planes; pl++)
            raw[pl] = int(imageLoad(src[pl], pos).x);
    }

    ivec4 ordered;
    ordered.x = raw[fmt_lut[0]];
    ordered.y = raw[fmt_lut[1]];
    ordered.z = raw[fmt_lut[2]];
    ordered.w = raw[fmt_lut[3]];
    return ordered;
}

/**
 * Applies the colour transform to pix in place.
 *
 * Blue and red become differences against green, green receives a
 * coefficient-weighted quarter of those differences, and rct_offset is then
 * added to both differences. The alpha component is left untouched.
 */
void transform_sample(inout ivec4 pix, ivec2 rct_coef)
{
    int db = pix.b - pix.g;
    int dr = pix.r - pix.g;

    pix.g += (db*rct_coef.g + dr*rct_coef.r) >> 2;
    pix.b  = db + rct_offset;
    pix.r  = dr + rct_offset;
}

/**
 * Copies line y of the slice from the source images into `tmp`, optionally
 * running every pixel through transform_sample().
 *
 * The pixels of the line are distributed over the invocations of the
 * workgroup. The function ends with an image memory barrier followed by a
 * workgroup barrier, so it has to be called by every invocation.
 *
 * sc         slice context, provides the slice position and the coefficients
 * sp         base position used for addressing `tmp`
 * w          number of pixels in the line
 * y          line number within the slice
 * apply_rct  whether to apply the colour transform
 */
void preload_rgb(in SliceContext sc, ivec2 sp, int w, int y, bool apply_rct)
{
    uint x = gl_LocalInvocationID.x;

    while (x < w) {
        ivec2 off = ivec2(x, y);

        ivec4 pix = load_components(sc.slice_pos + off);
        if (apply_rct)
            transform_sample(pix, sc.slice_rct_coef);

        imageStore(tmp, sp + LADDR(off), pix);

        x += gl_WorkGroupSize.x;
    }

    memoryBarrierImage();
    barrier();
}
#endif

/**
 * Encodes all lines of all planes of one slice.
 *
 * Executed by the whole workgroup. Non-Golomb builds with force_pcm set
 * write raw samples through encode_line_pcm() and return early; otherwise
 * every line goes through encode_line().
 *
 * sc         slice context (position, dimensions, transform coefficients)
 * slice_idx  index of the slice, used to locate its coder state
 */
void encode_slice(in SliceContext sc, uint slice_idx)
{
    ivec2 sp = sc.slice_pos;

#ifdef RGB
    /* RGB builds read samples from `tmp` (filled by preload_rgb()), whose
     * rows are addressed from a per-workgroup base rather than from the
     * slice position */
    sp.y = int(gl_WorkGroupID.y)*rgb_linecache;
#endif

#ifndef GOLOMB
    if (force_pcm) {
#ifndef RGB
        for (int c = 0; c < color_planes; c++) {

            /* Planes 1 and 2 have the chroma-subsampled height */
            int h = sc.slice_dim.y;
            if (c > 0 && c < 3)
                h = ceil_rshift(h, chroma_shift.y);

            /* Takes into account dual-plane YUV formats */
            int p = min(c, planes - 1);
            int comp = c - p;

            for (int y = 0; y < h; y++)
                encode_line_pcm(sc, src[p], sp, y, p, comp);
        }
#else
        for (int y = 0; y < sc.slice_dim.y; y++) {
            /* Raw mode: the line is copied without the colour transform */
            preload_rgb(sc, sp, sc.slice_dim.x, y, false);

            for (uint c = 0; c < color_planes; c++)
                encode_line_pcm(sc, tmp, sp, y, 0, rgb_plane_order[c]);
        }
#endif
        return;
    }
#endif

    /* State offset for each coded plane; planes 1 and 2 use the same one */
    u32vec4 slice_state_off = (slice_idx*codec_planes +
                               uvec4(0, 1, 1, 2))*plane_state_size;

#ifdef GOLOMB
    /* NOTE(review): the offsets are divided by 8 here, presumably to turn a
     * byte offset into an index into the VlcState array - confirm */
    slice_state_off >>= 3;
    init_golomb();
#endif

#ifndef RGB
    for (uint c = 0; c < color_planes; c++) {
        /* run_index starts from zero for every plane */
        int run_index = 0;

        /* Planes 1 and 2 have the chroma-subsampled height */
        int h = sc.slice_dim.y;
        if (c > 0 && c < 3)
            h = ceil_rshift(h, chroma_shift.y);

        /* Same plane/component split as in the raw path above */
        uint p = min(c, planes - 1);
        uint comp = c - p;

        for (int y = 0; y < h; y++)
            encode_line(sc, src[p], slice_state_off[c], sp, y, p,
                        comp, U8(context_model), run_index);
    }
#else
    /* A single run_index is carried across all lines and planes of the slice */
    int run_index = 0;
    for (int y = 0; y < sc.slice_dim.y; y++) {
        /* Copy the line into `tmp` with the colour transform applied */
        preload_rgb(sc, sp, sc.slice_dim.x, y, true);

        /* The planes of a line are interleaved, in rgb_plane_order */
        for (uint c = 0; c < color_planes; c++)
            encode_line(sc, tmp, slice_state_off[c],
                        sp, y, 0, rgb_plane_order[c],
                        U8(context_model), run_index);
    }
#endif
}

/**
 * Terminates the slice bitstream, appends its trailer and records the total
 * size in slice_results[slice_idx].
 *
 * Trailer layout, directly after the coded data:
 *   - the coded length as three bytes, most significant byte first
 *   - when has_crc is set: one zero byte, then a four-byte checksum, least
 *     significant byte first, calculated over everything that precedes it
 *     (zero byte included)
 *
 * main() calls this from invocation 0 only.
 */
void finalize_slice(in uint slice_idx)
{
#ifdef GOLOMB
    uint32_t total = hdr_len + flush_put_bits(pb);
#else
    uint32_t total = rac_terminate();
#endif

    u8buf dst = u8buf(slice_data + rc.bs_start);

    /* Coded length, bits 23..0 */
    u8vec4 len_bytes = unpack8(total);
    dst[total++].v = len_bytes.z;
    dst[total++].v = len_bytes.y;
    dst[total++].v = len_bytes.x;

    if (has_crc) {
        /* This zero byte is part of the checksummed data */
        dst[total++].v = uint8_t(0);

        /* Table-driven CRC, one byte per step, seeded with crcref */
        uint32_t crc = crcref;
        for (int i = 0; i < total; i++) {
            uint32_t idx = (crc & 0xFF) ^ uint32_t(dst[i].v);
            crc = crc_ieee[idx] ^ (crc >> 8);
        }

        /* A non-zero seed needs this final correction */
        if (crcref != 0x00000000)
            crc ^= 0x8CD88196;

        u8vec4 crc_bytes = unpack8(crc);
        dst[total++].v = crc_bytes.x;
        dst[total++].v = crc_bytes.y;
        dst[total++].v = crc_bytes.z;
        dst[total++].v = crc_bytes.w;
    }

    slice_results[slice_idx] = total;
}

/**
 * Entry point: one workgroup encodes one slice.
 *
 * Invocation 0 copies the slice's coder context into `rc` before the
 * workgroup starts encoding, and finalizes the slice once encoding is done.
 */
void main(void)
{
    /* Workgroups are numbered row by row */
    uint slice_idx = gl_WorkGroupID.x + gl_WorkGroupID.y*gl_NumWorkGroups.x;
    bool is_first = (gl_LocalInvocationID.x == 0);

    if (is_first)
        rc = slice_ctx[slice_idx].c;
    /* No invocation may start encoding before `rc` is set up */
    barrier();

    encode_slice(slice_ctx[slice_idx], slice_idx);

    if (is_first)
        finalize_slice(slice_idx);
}
