/*
 * GStreamer
 * Copyright (C) 2024 Collabora Ltd.
 *  Authors: Daniel Morin <daniel.morin@collabora.com>
 *           Vineet Suryan <vineet.suryan@collabora.com>
 *           Santosh Mahto <santosh.mahto@collabora.com>
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301, USA.
 */

/**
 * SECTION:element-yolosegv8tensordec
 * @short_description: Decode tensors from FastSAM/YOLOv8 segmentation
 * models
 *
 * This element parses per-buffer inference tensor metadata generated by an
 * upstream inference element.
 *
 * ## Example launch command:
 *
 * The test image file, model file and labels file can be found here:
 * https://gitlab.collabora.com/gstreamer/onnx-models
 *
 * gst-launch-1.0 v4l2src device=/dev/video4 ! videorate max-rate=3 \
 *  ! videoconvertscale ! video/x-raw, pixel-aspect-ratio=1/1 \
 *  ! onnxinference \
 *    model-file=/path/to/onnx-models/models/yolov8s-seg.onnx \
 *  ! yolosegv8tensordec class-confidence-threshold=0.8 iou-threshold=0.3 \
 *    max-detections=100 \
 *    label-file=/path/to/onnx-models/labels/COCO_classes.txt \
 *  ! segmentationoverlay \
 *  ! glimagesink sink="gtkglsink processing-deadline=300000000"
 *
 * The original YOLO repository is located at
 * https://github.com/ultralytics/ultralytics.
 * For easy experimentation, an object segmentation model based on the YOLO
 * architecture in ONNX format can be found at https://col.la/gstonnxmodelseg.
 * This model already has the required tensor-ids embedded in the model.
 * It's also possible to embed tensor-ids into any model based on the YOLO
 * architecture to allow this tensor decoder to decode tensors. This process
 * is described in the README of this repository: https://col.la/gstonnxmodels
 *
 * Since: 1.28
 */

#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#include "gstyolosegtensordecoder.h"

#include <gst/analytics/analytics.h>
#include <gio/gio.h>

#include <math.h>

/* Tensor-id strings of the two model outputs this decoder consumes. They are
 * interned as GQuarks in class_init and matched against the tensors found in
 * the buffer's GstTensorMeta. */
#define YOLO_SEGMENTATION_LOGITS "yolo-v8-segmentation-out-protos"
#define YOLO_SEGMENTATION_DETECTION_MASK "yolo-v8-segmentation-out-detections"

/* Quarks for the tensor-ids above (set in class_init). */
GQuark YOLO_SEGMENTATION_LOGITS_TENSOR_ID;
GQuark YOLO_SEGMENTATION_DETECTION_MASK_ID;

/* Sink pad template: raw video whose caps advertise a
 * "yolo-v8-segmentation-out" tensor group containing two float32,
 * col-major tensors:
 *  - "yolo-v8-segmentation-out-detections" with 3 dims <1, [1,max], [1,max]>
 *  - "yolo-v8-segmentation-out-protos" with 4 dims
 *    <1, [1,max], [1,max], [1,max]>; transform_ip reads dims[1] as the number
 *    of prototype masks and dims[2]/dims[3] as the mask width/height.
 */
/* *INDENT-OFF* */
static GstStaticPadTemplate gst_yolo_seg_tensor_decoder_sink_template =
GST_STATIC_PAD_TEMPLATE ("sink",
    GST_PAD_SINK,
    GST_PAD_ALWAYS,
    GST_STATIC_CAPS ("video/x-raw,"
        "tensors=(structure)["
          "tensorgroups,"
            "yolo-v8-segmentation-out=(/uniquelist){"
            "(GstCaps)["
              "tensor/strided,"
                "tensor-id=yolo-v8-segmentation-out-detections,"
                "dims=(int)<1, [1,max], [1,max]>,"
                "dims-order=(string)col-major,"
                "type=(string)float32"
              "],"
            "(GstCaps)["
              "tensor/strided,"
                "tensor-id=yolo-v8-segmentation-out-protos,"
                "dims=(int)<1, [1,max], [1,max], [1,max]>,"
                "dims-order=(string)col-major,"
                "type=(string)float32"
              "]"
            "}"
        "]"
      ));
/* *INDENT-ON* */


/* Element registration entry point for the plugin. */
GST_ELEMENT_REGISTER_DEFINE (yolo_seg_tensor_decoder, "yolosegv8tensordec",
    GST_RANK_SECONDARY, GST_TYPE_YOLO_SEG_TENSOR_DECODER);

/* Debug category of this element, initialized in class_init. */
GST_DEBUG_CATEGORY_STATIC (yolo_seg_tensor_decoder_debug);
#define GST_CAT_DEFAULT yolo_seg_tensor_decoder_debug

/* Debugging helper state describing which candidate fields to dump. */
struct _DebugCandidates
{
  gpointer self;                /* Object the debug output refers to */
  gsize fields;                 /* Number of fields to debug */
  gsize offset;                 /* Fields offset */
  gsize start;                  /* Index of the first field to debug */
};
typedef struct _DebugCandidates DebugCandidates;

/* Forward declarations of the virtual-method implementations that
 * class_init installs. */
static gboolean
gst_yolo_seg_tensor_decoder_stop (GstBaseTransform * trans);

static GstFlowReturn
gst_yolo_seg_tensor_decoder_transform_ip (GstBaseTransform * trans,
    GstBuffer * buf);

static void
gst_yolo_seg_tensor_decoder_object_found (GstYoloTensorDecoder * od,
    GstAnalyticsRelationMeta * rmeta, BBox * bb, gfloat confidence,
    GQuark class_quark, const gfloat * candidate_masks, gsize offset,
    guint count);

/* GstYoloSegTensorDecoder is a subclass of GstYoloTensorDecoder. */
G_DEFINE_TYPE (GstYoloSegTensorDecoder, gst_yolo_seg_tensor_decoder,
    GST_TYPE_YOLO_TENSOR_DECODER);

/* gst_yolo_seg_tensor_decoder_stop:
 * @trans: Instance
 * @return: %TRUE, or the result of the parent class's stop vfunc when the
 *   parent implements one.
 *
 * Forget the cached mask geometry and release the mask buffer pool, so the
 * next processed buffer recreates them, then chain up to the parent class.
 */
static gboolean
gst_yolo_seg_tensor_decoder_stop (GstBaseTransform * trans)
{
  GstYoloSegTensorDecoder *self = GST_YOLO_SEG_TENSOR_DECODER (trans);
  GstBaseTransformClass *parent_class =
      GST_BASE_TRANSFORM_CLASS (gst_yolo_seg_tensor_decoder_parent_class);

  self->mask_w = 0;
  self->mask_h = 0;
  self->mask_length = 0;
  /* Only valid while a logits buffer is mapped in transform_ip. */
  self->logits_tensor = NULL;

  if (self->mask_pool)
    gst_buffer_pool_set_active (self->mask_pool, FALSE);
  g_clear_object (&self->mask_pool);

  /* class_init overrides the stop vfunc; chain up so any cleanup done by
   * the parent class still runs. */
  if (parent_class->stop)
    return parent_class->stop (trans);

  return TRUE;
}

/* gst_yolo_seg_tensor_decoder_class_init:
 * @klass: Class structure
 *
 * Set up the debug category, intern the tensor-id quarks, register element
 * metadata and the sink pad template, and install the vfunc overrides.
 */
static void
gst_yolo_seg_tensor_decoder_class_init (GstYoloSegTensorDecoderClass * klass)
{
  GstElementClass *element_class = (GstElementClass *) klass;
  GstBaseTransformClass *basetransform_class = (GstBaseTransformClass *) klass;
  GstYoloTensorDecoderClass *od_class = (GstYoloTensorDecoderClass *) klass;

  /* Define GstYoloSegTensorDecoder debug category. */
  GST_DEBUG_CATEGORY_INIT (yolo_seg_tensor_decoder_debug,
      "yolosegv8tensordec", 0, "Tensor decoder for Yolo segmentation models");

  YOLO_SEGMENTATION_DETECTION_MASK_ID =
      g_quark_from_static_string (YOLO_SEGMENTATION_DETECTION_MASK);
  YOLO_SEGMENTATION_LOGITS_TENSOR_ID =
      g_quark_from_static_string (YOLO_SEGMENTATION_LOGITS);

  gst_element_class_set_static_metadata (element_class,
      "YOLO v8-11 segmentation tensor decoder", "Tensordecoder/Video",
      "Decode tensors output from the inference of Yolo or FastSAM model (segmentation)"
      " on video frames. It works with YOLO version >= 8 and FastSAM models.",
      "Daniel Morin <daniel.morin@collabora.com>, Santosh Mahto <santosh.mahto@collabora.com>");

  gst_element_class_add_static_pad_template (element_class,
      &gst_yolo_seg_tensor_decoder_sink_template);

  basetransform_class->transform_ip = gst_yolo_seg_tensor_decoder_transform_ip;
  basetransform_class->stop = gst_yolo_seg_tensor_decoder_stop;

  /* Called by the parent decoder for every detected object. */
  od_class->object_found = gst_yolo_seg_tensor_decoder_object_found;

  /* Workaround hotdoc bug */
  gst_type_mark_as_plugin_api (GST_TYPE_YOLO_TENSOR_DECODER, 0);
}

/* gst_yolo_seg_tensor_decoder_init:
 * @self: Instance
 *
 * Start without any mask geometry or mask pool; both are created lazily by
 * transform_ip. Passthrough mode is explicitly disabled.
 */
static void
gst_yolo_seg_tensor_decoder_init (GstYoloSegTensorDecoder * self)
{
  GstBaseTransform *base = GST_BASE_TRANSFORM (self);

  self->mask_pool = NULL;
  self->mask_length = 0;
  self->mask_h = 0;
  self->mask_w = 0;
  memset (&self->mask_roi, 0, sizeof (BBox));

  gst_base_transform_set_passthrough (base, FALSE);
}

/* gst_yolo_seg_tensor_decoder_get_tensors:
 * @self: Instance
 * @buf: Buffer to search for tensor meta
 * @logits_tensor: (out): Protos tensor, set only when %TRUE is returned
 * @detections_tensor: (out): Detections tensor, set only when %TRUE is returned
 * @return: %TRUE if one GstTensorMeta holds both tensors with the expected
 *   type, dims order and rank, and the detections tensor is wide enough for
 *   the number of prototype masks.
 */
static gboolean
gst_yolo_seg_tensor_decoder_get_tensors (GstYoloSegTensorDecoder * self,
    GstBuffer * buf, const GstTensor ** logits_tensor,
    const GstTensor ** detections_tensor)
{
  /* G_MAXSIZE acts as a wildcard for dimensions of any size. */
  static const gsize logits_dims[4] = { 1, G_MAXSIZE, G_MAXSIZE, G_MAXSIZE };
  static const gsize detections_dims[3] = { 1, G_MAXSIZE, G_MAXSIZE };
  GstMeta *meta = NULL;
  gpointer iter_state = NULL;

  if (!gst_buffer_get_meta (buf, GST_TENSOR_META_API_TYPE)) {
    GST_WARNING_OBJECT (self,
        "missing tensor meta from buffer %" GST_PTR_FORMAT, buf);
    return FALSE;
  }

  while ((meta = gst_buffer_iterate_meta_filtered (buf, &iter_state,
              GST_TENSOR_META_API_TYPE))) {
    GstTensorMeta *tmeta = (GstTensorMeta *) meta;
    const GstTensor *logits;
    const GstTensor *detections;
    gsize num_masks;

    logits = gst_tensor_meta_get_typed_tensor (tmeta,
        YOLO_SEGMENTATION_LOGITS_TENSOR_ID, GST_TENSOR_DATA_TYPE_FLOAT32,
        GST_TENSOR_DIM_ORDER_COL_MAJOR, G_N_ELEMENTS (logits_dims),
        logits_dims);
    if (logits == NULL)
      continue;

    detections = gst_tensor_meta_get_typed_tensor (tmeta,
        YOLO_SEGMENTATION_DETECTION_MASK_ID, GST_TENSOR_DATA_TYPE_FLOAT32,
        GST_TENSOR_DIM_ORDER_COL_MAJOR, G_N_ELEMENTS (detections_dims),
        detections_dims);
    if (detections == NULL)
      continue;

    /* Keep the full gsize width: a guint could truncate the mask count. */
    num_masks = logits->dims[1];

    /* Require room for 4 + 1 leading fields (assumed: box + at least one
     * class score) plus one coefficient per prototype mask. */
    if (detections->dims[1] < 4 + 1 + num_masks) {
      GST_WARNING_OBJECT (self, "Ignore tensor because dims[1] is %zu < %zu",
          detections->dims[1], (gsize) (4 + 1 + num_masks));
      continue;
    }

    *logits_tensor = logits;
    *detections_tensor = detections;
    return TRUE;
  }

  return FALSE;
}

/* gst_yolo_seg_tensor_decoder_transform_ip:
 * @trans: Instance
 * @buf:inout: Buffer containing media and where tensors can be attached
 * @return: Flow errors
 * Decode Yolo tensors, post-process tensors and store decoded information
 * into an analytics-meta that is attached to the buffer before being pushed
 * downstream.
 *
 * Buffers without a usable detections/protos tensor pair are pushed
 * unchanged. While the detections are decoded, the protos (logits) tensor
 * stays mapped in self->map_info_logits so that object_found() can build one
 * mask per detected object.
 */
static GstFlowReturn
gst_yolo_seg_tensor_decoder_transform_ip (GstBaseTransform * trans,
    GstBuffer * buf)
{
  GstYoloSegTensorDecoder *self = GST_YOLO_SEG_TENSOR_DECODER (trans);
  GstYoloTensorDecoder *od = GST_YOLO_TENSOR_DECODER (trans);
  GstAnalyticsRelationMeta *rmeta;
  gsize mask_w, mask_h, num_masks;
  const GstTensor *detections_tensor;
  const GstTensor *logits_tensor;
  GstFlowReturn ret = GST_FLOW_OK;

  if (!gst_yolo_seg_tensor_decoder_get_tensors (self, buf, &logits_tensor,
          &detections_tensor)) {
    GST_WARNING_OBJECT (self,
        "Couldn't find logit or detections tensor, skipping");
    return GST_FLOW_OK;
  }

  /* The mask geometry below divides by the video size. */
  if (od->video_info.width <= 0 || od->video_info.height <= 0) {
    GST_ELEMENT_ERROR (trans, CORE, NEGOTIATION, (NULL),
        ("No valid video size negotiated"));
    return GST_FLOW_NOT_NEGOTIATED;
  }

  rmeta = gst_buffer_add_analytics_relation_meta (buf);
  if (rmeta == NULL) {
    GST_ELEMENT_ERROR (trans, STREAM, FAILED, (NULL),
        ("Analytics Relation meta allocation failed"));
    return GST_FLOW_ERROR;
  }

  num_masks = logits_tensor->dims[1];
  mask_w = logits_tensor->dims[2];
  mask_h = logits_tensor->dims[3];

  /* A new mask size invalidates the pool of mask buffers. */
  if (self->mask_w != mask_w || self->mask_h != mask_h) {
    self->mask_w = mask_w;
    self->mask_h = mask_h;
    self->mask_length = mask_w * mask_h;

    if (self->mask_pool) {
      gst_buffer_pool_set_active (self->mask_pool, FALSE);
      g_clear_object (&self->mask_pool);
    }
  }

  /* The detections need to be cropped to fit the SAR of the image. */
  /* TODO: We're reconstructing the transformation that was done on the
   * original image based on the assumption that the complete image without
   * deformation would be analyzed. This assumption is not alway true and
   * we should try to find a way to convey this transformation information
   * and retrieve from here to know the transformation that need to be done
   * on the mask.*/
  /* Recomputed for every buffer (it is cheap) because it also depends on the
   * negotiated video size, which can change while the mask size stays the
   * same. */
  if (od->video_info.width > od->video_info.height) {
    self->bb2mask_gain = ((gfloat) self->mask_w) / od->video_info.width;
    self->mask_roi.x = 0;
    self->mask_roi.w = self->mask_w;
    self->mask_roi.h = ((gfloat) self->bb2mask_gain) * od->video_info.height;
    self->mask_roi.y = (self->mask_h - self->mask_roi.h) / 2;
  } else {
    self->bb2mask_gain = ((gfloat) self->mask_h) / od->video_info.height;
    self->mask_roi.y = 0;
    self->mask_roi.h = self->mask_h;
    self->mask_roi.w = self->bb2mask_gain * od->video_info.width;
    self->mask_roi.x = (self->mask_w - self->mask_roi.w) / 2;
  }

  if (self->mask_pool == NULL) {
    GstVideoInfo minfo;
    GstBufferPool *pool;
    GstStructure *config;
    GstCaps *caps;
    gboolean configured;

    gst_video_info_init (&minfo);
    if (!gst_video_info_set_format (&minfo, GST_VIDEO_FORMAT_GRAY8,
            self->mask_w, self->mask_h)) {
      GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL),
          ("Invalid mask size %zux%zu", mask_w, mask_h));
      return GST_FLOW_ERROR;
    }
    caps = gst_video_info_to_caps (&minfo);
    pool = gst_video_buffer_pool_new ();

    config = gst_buffer_pool_get_config (pool);
    /* Use the size GstVideoInfo computed for these caps: GRAY8 rows may be
     * padded, so it can be larger than mask_w * mask_h. */
    gst_buffer_pool_config_set_params (config, caps,
        GST_VIDEO_INFO_SIZE (&minfo), 0, 0);
    gst_buffer_pool_config_add_option (config,
        GST_BUFFER_POOL_OPTION_VIDEO_META);
    /* set_config takes ownership of config, even on failure. */
    configured = gst_buffer_pool_set_config (pool, config) &&
        gst_buffer_pool_set_active (pool, TRUE);
    gst_caps_unref (caps);

    if (!configured) {
      GST_ELEMENT_ERROR (self, RESOURCE, SETTINGS, (NULL),
          ("Couldn't configure the mask buffer pool"));
      gst_object_unref (pool);
      return GST_FLOW_ERROR;
    }
    self->mask_pool = pool;
  }

  /* Map the protos tensor for reading; object_found() reads it directly. */
  if (!gst_buffer_map (logits_tensor->data, &self->map_info_logits,
          GST_MAP_READ)) {
    GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL),
        ("Couldn't map logits tensor buffer: %" GST_PTR_FORMAT,
            logits_tensor->data));
    return GST_FLOW_ERROR;
  }

  /* object_found() reads num_masks planes of mask_length floats each; refuse
   * tensors whose data is too small for that (overflow-safe comparison). */
  if (self->mask_length == 0 || num_masks >
      self->map_info_logits.size / sizeof (gfloat) / self->mask_length) {
    GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL),
        ("Logits tensor buffer too small: %zu bytes for %zu x %zu x %zu floats",
            self->map_info_logits.size, num_masks, mask_w, mask_h));
    gst_buffer_unmap (logits_tensor->data, &self->map_info_logits);
    return GST_FLOW_ERROR;
  }

  self->logits_tensor = logits_tensor;

  if (!gst_yolo_tensor_decoder_decode_f32 (od, rmeta, detections_tensor,
          num_masks))
    ret = GST_FLOW_ERROR;

  /* The tensor data is only valid while mapped; don't keep a stale pointer. */
  self->logits_tensor = NULL;
  gst_buffer_unmap (logits_tensor->data, &self->map_info_logits);

  return ret;
}

static float
sigmoid (float x)
{
  /* Check for positive overflow */
  if (x > 0) {
    double exp_neg_x = exp (-x);
    return 1.0 / (1.0 + exp_neg_x);
  }
  /* Check for negative overflow and improve stability for negative x */
  else {
    double exp_x = exp (x);
    return exp_x / (1.0 + exp_x);
  }
}

/* gst_yolo_seg_tensor_decoder_object_found:
 * @od: Instance
 * @rmeta: Relation meta of the buffer being decoded
 * @bb: Bounding box of the object, in video coordinates
 * @confidence: Detection confidence
 * @class_quark: Class label of the object
 * @candidate_masks: First mask coefficient of this candidate; coefficient k
 *   is at candidate_masks[offset * k]
 * @offset: Distance between consecutive coefficients of one candidate
 * @count: Value written into the mask pixels that belong to the object
 *
 * Add an object-detection mtd for @bb and, when the box overlaps the mask
 * plane, a GRAY8 instance segmentation mask built from the mapped prototype
 * masks. The two mtds are related to each other in both directions.
 * Must only be called from transform_ip while the logits tensor is mapped.
 */
static void
gst_yolo_seg_tensor_decoder_object_found (GstYoloTensorDecoder * od,
    GstAnalyticsRelationMeta * rmeta, BBox * bb, gfloat confidence,
    GQuark class_quark, const gfloat * candidate_masks, gsize offset,
    guint count)
{
  GstYoloSegTensorDecoder *self = GST_YOLO_SEG_TENSOR_DECODER (od);
  const gfloat *data_logits = (const gfloat *) self->map_info_logits.data;
  const gsize num_masks = self->logits_tensor->dims[1];
  GstAnalyticsODMtd od_mtd;
  GstAnalyticsMtd seg_mtd;
  GstBuffer *mask_buf = NULL;
  GstMapInfo out_mask_info;
  GstVideoMeta *vmeta;
  /* NOTE(review): only region_ids[0] (= 0) is registered (region_count is 1)
   * while object pixels are written with `count`; confirm this matches the
   * GstAnalyticsSegmentationMtd region semantics. */
  guint region_ids[2] = { 0, count };
  BBox bb_mask;
  gint x0, y0, x1, y1;

  if (!gst_analytics_relation_meta_add_od_mtd (rmeta, class_quark,
          bb->x, bb->y, bb->w, bb->h, confidence, &od_mtd)) {
    GST_WARNING_OBJECT (self, "Couldn't add object detection mtd");
    return;
  }

  /* Project the bounding box from video coordinates onto the mask plane. */
  bb_mask.x = self->bb2mask_gain * bb->x + self->mask_roi.x;
  bb_mask.y = self->bb2mask_gain * bb->y + self->mask_roi.y;
  bb_mask.w = self->bb2mask_gain * bb->w;
  bb_mask.h = self->bb2mask_gain * bb->h;

  /* Clip to the mask plane: a projected box reaching outside it would read
   * past the prototype planes and write past the end of the mask buffer. */
  x0 = CLAMP ((gint) bb_mask.x, 0, (gint) self->mask_w);
  y0 = CLAMP ((gint) bb_mask.y, 0, (gint) self->mask_h);
  x1 = CLAMP ((gint) bb_mask.x + (gint) bb_mask.w, 0, (gint) self->mask_w);
  y1 = CLAMP ((gint) bb_mask.y + (gint) bb_mask.h, 0, (gint) self->mask_h);
  if (x1 <= x0 || y1 <= y0) {
    GST_DEBUG_OBJECT (self, "Object doesn't overlap the mask, no mask added");
    return;
  }

  if (gst_buffer_pool_acquire_buffer (self->mask_pool, &mask_buf,
          NULL) != GST_FLOW_OK) {
    GST_WARNING_OBJECT (self, "Couldn't acquire a mask buffer");
    return;
  }

  /* The pool is configured with GST_BUFFER_POOL_OPTION_VIDEO_META. */
  vmeta = gst_buffer_get_video_meta (mask_buf);
  if (vmeta == NULL
      || !gst_buffer_map (mask_buf, &out_mask_info, GST_MAP_WRITE)) {
    GST_WARNING_OBJECT (self, "Couldn't prepare mask buffer %" GST_PTR_FORMAT,
        mask_buf);
    gst_buffer_unref (mask_buf);
    return;
  }

  /* The mask only covers the object's box within the full-size buffer. */
  vmeta->width = x1 - x0;
  vmeta->height = y1 - y0;

  for (gint my = y0; my < y1; my++) {
    /* Honor the video meta layout: GRAY8 rows may be padded. */
    guint8 *row = out_mask_info.data + vmeta->offset[0] +
        (gsize) (my - y0) * vmeta->stride[0];

    for (gint mx = x0; mx < x1; mx++) {
      const gsize j = (gsize) my * self->mask_w + mx;
      float sum = 0.0f;

      /* Linear combination of the prototype masks at this pixel. */
      for (gsize k = 0; k < num_masks; ++k) {
        GST_TRACE_OBJECT (self, "protos data at ((mx=%d,my=%d)=%zu, %zu) is %f",
            mx, my, j, k, data_logits[k * self->mask_length + j]);
        sum += candidate_masks[offset * k] *
            data_logits[k * self->mask_length + j];
      }
      /* NOTE(review): pixels are 8-bit, so a `count` above 255 wraps. */
      row[mx - x0] = sigmoid (sum) > 0.5 ? count : 0;
    }
  }

  /* Unmap before handing the buffer over to the segmentation mtd. */
  gst_buffer_unmap (mask_buf, &out_mask_info);

  /* NOTE(review): assumes add_segmentation_mtd takes ownership of mask_buf
   * (the buffer is not unreffed here) — confirm. */
  if (!gst_analytics_relation_meta_add_segmentation_mtd (rmeta, mask_buf,
          GST_SEGMENTATION_TYPE_INSTANCE, 1, region_ids, bb->x, bb->y, bb->w,
          bb->h, &seg_mtd)) {
    GST_WARNING_OBJECT (self, "Couldn't add segmentation mtd");
    return;
  }

  gst_analytics_relation_meta_set_relation (rmeta,
      GST_ANALYTICS_REL_TYPE_RELATE_TO, od_mtd.id, seg_mtd.id);
  gst_analytics_relation_meta_set_relation (rmeta,
      GST_ANALYTICS_REL_TYPE_RELATE_TO, seg_mtd.id, od_mtd.id);
}
