/* GStreamer
 * Copyright (C) 2024 Collabora Ltd
 *  @author: Daniel Morin <daniel.morin@collabora.com>
 *
 * gstanalyticssegmentationmtd.c
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301, USA.
 */
#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#include "gstanalyticssegmentationmtd.h"
#include <gst/video/video-info.h>

GST_DEBUG_CATEGORY_EXTERN (gst_analytics_relation_meta_debug);
#define GST_CAT_DEFAULT gst_analytics_relation_meta_debug

/**
 * SECTION: gstanalyticssegmentationmtd
 * @title: GstAnalyticsSegmentationMtd
 * @short_description: An analytics metadata for image segmentation inside a
 * #GstAnalyticsRelationMeta
 * @symbols:
 * - GstAnalyticsSegmentationMtd
 *  @see_also: #GstAnalyticsMtd, #GstAnalyticsRelationMeta
 *
 *  This type of metadata holds information on which pixels belong to
 *  a region of the image representing a type of object.
 *
 *  It supports two types of segmentation, semantic or instance:
 *   * Semantic: All objects of the same type have the same id
 *   * Instance: Each instance of an object has a different id
 *
 *  The results of the segmentation are stored in a #GstBuffer that has a
 *  #GstVideoMeta associated with it. This buffer is stored in the
 *  GstAnalyticsSegmentationMtd using
 *  #gst_analytics_relation_meta_add_segmentation_mtd(). The #GstBuffer
 *  containing the segmentation mask is image-like but the color values are
 *  arbitrary values, referred by region-id in this API, without meaning beyond
 *  specifying that two pixels in the original image with the same values in
 *  their corresponding mask value belong to the same region.
 *
 *  To further describe a region, the #GstAnalyticsSegmentationMtd can be
 *  associated with other #GstAnalyticsMtd. Since region ids are
 *  generated by the segmentation process itself and are not always sequential,
 *  we use a map of indexes to region ids starting with 0 without discontinuity
 *  which facilitates N-to-N mapping with other #GstAnalyticsMtd. For
 *  example it can be associated with #GstAnalyticsClsMtd to describe the class
 *  of object matching the pixels of a segmented region.
 *
 *  Example: Associate Instance Segmentation with Classification
 *
 *  In the following example the segmentation process will fill segmask with
 *  values of 0 for background, 12 for the first region which corresponds to
 *  a strawberry, 7 for the second region that also corresponds to a
 *  strawberry in the image and 31 for the third region that corresponds to a
 *  leaf in the image.
 *  region_ids is filled during segmentation post-processing.
 *
 *  region_ids:
 *  |region-index | region-id |
 *  |-------------|-----------|
 *  | 0           | 0         |
 *  | 1           | 12        |
 *  | 2           | 7         |
 *  | 3           | 31        |
 *
 *  region_count = 4
 *
 *  ``` C
 *    GstAnalyticsSegmentationMtd segmtd;
 *    GstAnalyticsClsMtd clsmtd;
 *    GstBuffer *segmask, *img;
 *    guint *region_ids;
 *    gsize region_count, class_count;
 *    gfloat *class_confidence;
 *    GQuark *classes;
 *
 *    ... (segmentation filling segmask based on img)
 *
 *    gst_analytics_relation_meta_add_segmentation_mtd (rmeta, segmask,
 *      GST_SEGMENTATION_TYPE_INSTANCE, region_count, region_ids,
 *      0, 0, image_width, image_height, &segmtd);
 *    class_count = region_count;
 *
 *    ... (class-index must match and correspond to region-index)
 *    classes [0]  = g_quark_from_string ("background");
 *    classes [1]  = g_quark_from_string ("strawberry");
 *    classes [2] = g_quark_from_string ("strawberry");
 *    classes [3] = g_quark_from_string ("leaf");
 *
 *    ... (set confidence level for each class associated with a region
 *    ... where -1.0 mean undefined.)
 *    class_confidence [0] = -1.0;
 *    class_confidence [1] = 0.6;
 *    class_confidence [2] = 0.9;
 *    class_confidence [3] = 0.8;
 *
 *    gst_analytics_relation_meta_add_cls_mtd (rmeta, class_count,
 *      class_confidence, classes, &clsmtd);
 *
 *    gst_analytics_relation_meta_set_relation (rmeta,
 *     GST_ANALYTICS_REL_TYPE_RELATE_TO, segmtd.id, clsmtd.id);
 *  ```
 *
 *  Example: Associate Semantic Segmentation with Classification
 *  Assuming the same context as for Instance Segmentation above but instead
 *  a semantic segmentation is performed, therefore region-id-12 and region-id-7
 *  are now represented by the same region-id-12
 *
 *  region_ids:
 *  |region-index | region-id |
 *  |-------------|-----------|
 *  | 0           | 0         |
 *  | 1           | 12        |
 *  | 2           | 31        |
 *
 *  The code remains the same except that we set all confidence levels to
 *  undefined (-1.0).
 *
 *  ```
 *    ... (class-index must match and correspond to region-index)
 *    classes [0]  = g_quark_from_string ("background");
 *    classes [1]  = g_quark_from_string ("strawberry");
 *    classes [2] = g_quark_from_string ("leaf");
 *
 *    ... (set confidence level for each class associated with a region
 *    ... where -1.0 mean undefined.)
 *    class_confidence [0] = -1.0;
 *    class_confidence [1] = -1.0;
 *    class_confidence [2] = -1.0;
 *
 *    gst_analytics_relation_meta_add_cls_mtd (rmeta, class_count,
 *      class_confidence, classes, &clsmtd);
 *
 *    gst_analytics_relation_meta_set_relation (rmeta,
 *     GST_ANALYTICS_REL_TYPE_RELATE_TO, segmtd.id, clsmtd.id);
 *  ```
 *
 *  Example: Retrieving class associated with a segmentation region-id-12
 *  This is the typical case for an overlay: as we visit the segmentation mask
 *  we find region-id values
 *
 *  ```
 *  gsize idx;
 *  gst_analytics_segmentation_mtd_get_region_index (&segmtd, &idx, 12);
 *  gst_analytics_relation_meta_get_direct_related (rmeta, segmtd.id,
 *    GST_ANALYTICS_REL_TYPE_RELATE_TO, gst_analytics_cls_mtd_get_type (),
 *    NULL, &clsmtd);
 *
 *   GQuark region_class = gst_analytics_cls_mtd_get_quark (&clsmtd, idx);
 *   ...
 *  ```
 *
 *  Since: 1.26
 */

/* Forward declarations of the callbacks referenced by segmentation_impl
 * below; both are defined later in this file. */
static void gst_analytics_segmentation_mtd_clear (GstBuffer * buffer,
    GstAnalyticsMtd * mtd);

static gboolean
gst_analytics_segmentation_mtd_transform (GstBuffer * transbuf,
    GstAnalyticsMtd * transmtd, GstBuffer * buffer, GQuark type, gpointer data);

/* Implementation descriptor for segmentation metadata. Its address is what
 * gst_analytics_segmentation_mtd_get_mtd_type() returns as the type id, and
 * it is handed to gst_analytics_relation_meta_add_mtd() when a segmentation
 * mtd is added. The initializer supplies, in order: the name string, the
 * transform callback and the clear callback. */
static const GstAnalyticsMtdImpl segmentation_impl = {
  "segmentation",
  gst_analytics_segmentation_mtd_transform,
  gst_analytics_segmentation_mtd_clear
};

/*
 * GstAnalyticsSegMtdData:
 * @type: #GstSegmentationType indicate if the mask values are object/region-id
 * (in the case of instance segmentation) or object/region-type (in the case
 * of semantic segmentation).
 * @masks: #GstBuffer used to store segmentation masks
 * @masks_loc_x: Left coordinate of the rectangle, in the image, that the
 * masks correspond to
 * @masks_loc_y: Top coordinate of the rectangle, in the image, that the
 * masks correspond to
 * @masks_loc_w: Width of the rectangle, in the image, that the masks
 * correspond to
 * @masks_loc_h: Height of the rectangle, in the image, that the masks
 * correspond to
 * @region_count: Number of regions in the segmentation masks; also the number
 * of entries stored in @region_ids
 * @region_ids: Indexed region ids (flexible array member, stored inline right
 * after the fixed part of the structure)
 *
 * Store segmentation results where each value represent a group to which
 * belong the corresponding pixel from original image where segmentation was
 * performed. All values equal in @masks form a mask defining all the
 * pixel belonging to the same segmented region from the original image. The
 * GstVideoMeta attached to the @masks, describe masks resolution, padding,
 * format, ... The format in video meta has a special meaning in the context
 * of the mask, GRAY8 mean that @masks value can take 256 values which mean
 * 256 segmented region can be represented.
 *
 */
typedef struct _GstAnalyticsSegMtdData
{
  GstSegmentationType type;
  GstBuffer *masks;

  /* Rectangle of the original image covered by @masks */
  gint masks_loc_x;
  gint masks_loc_y;
  guint masks_loc_w;
  guint masks_loc_h;

  gsize region_count;
  guint32 region_ids[];         /* Must be last */
} GstAnalyticsSegMtdData;

/**
 * gst_analytics_segmentation_mtd_get_mtd_type:
 *
 * Get an instance of #GstAnalyticsMtdType that represent segmentation
 * metadata type.
 *
 * Returns: A #GstAnalyticsMtdType type
 *
 * Since: 1.26
 */
GstAnalyticsMtdType
gst_analytics_segmentation_mtd_get_mtd_type (void)
{
  return (GstAnalyticsMtdType) & segmentation_impl;
}

/**
 * gst_analytics_segmentation_mtd_get_mask:
 * @handle: Instance
 * @masks_loc_x: (out caller-allocates)(nullable): Left coordinate of the
 * rectangle corresponding to the mask in the image.
 * @masks_loc_y: (out caller-allocates)(nullable): Top coordinate of the
 * rectangle corresponding to the mask in the image.
 * @masks_loc_w: (out caller-allocates)(nullable): Width of the rectangle
 * corresponding to the mask in the image.
 * @masks_loc_h: (out caller-allocates)(nullable): Height of the rectangle
 * corresponding to the mask in the image.
 *
 * Get segmentation mask data. Each location output argument is only written
 * when it is not %NULL. A new reference to the mask buffer is returned.
 *
 * Returns: (transfer full): Segmentation mask data stored in a #GstBuffer,
 * or %NULL if @handle is %NULL or its data can't be retrieved.
 *
 * Since: 1.26
 */
GstBuffer *
gst_analytics_segmentation_mtd_get_mask (const GstAnalyticsSegmentationMtd *
    handle, gint * masks_loc_x, gint * masks_loc_y, guint * masks_loc_w, guint *
    masks_loc_h)
{
  GstAnalyticsSegMtdData *mtddata;

  /* This function returns a pointer: fail with NULL, not FALSE */
  g_return_val_if_fail (handle, NULL);

  mtddata = gst_analytics_relation_meta_get_mtd_data (handle->meta, handle->id);
  g_return_val_if_fail (mtddata != NULL, NULL);

  if (masks_loc_x)
    *masks_loc_x = mtddata->masks_loc_x;
  if (masks_loc_y)
    *masks_loc_y = mtddata->masks_loc_y;
  if (masks_loc_w)
    *masks_loc_w = mtddata->masks_loc_w;
  if (masks_loc_h)
    *masks_loc_h = mtddata->masks_loc_h;

  return gst_buffer_ref (mtddata->masks);
}

/**
 * gst_analytics_segmentation_mtd_get_region_index:
 * @handle: Instance
 * @index: (out caller-allocates)(not nullable): Region index
 * @id: Region id
 *
 * Look up the index of the first region whose id equals @id. @index is only
 * written when such a region is found.
 *
 * Returns: TRUE if a region with @id exist, otherwise FALSE
 *
 * Since: 1.26
 */
gboolean
gst_analytics_segmentation_mtd_get_region_index (const
    GstAnalyticsSegmentationMtd * handle, gsize * index, guint id)
{
  GstAnalyticsSegMtdData *mtddata;
  gsize pos = 0;

  g_return_val_if_fail (handle, FALSE);
  g_return_val_if_fail (index != NULL, FALSE);

  mtddata = gst_analytics_relation_meta_get_mtd_data (handle->meta, handle->id);
  g_return_val_if_fail (mtddata != NULL, FALSE);

  /* Linear scan: stop at the first matching id or at the end of the table */
  while (pos < mtddata->region_count && mtddata->region_ids[pos] != id)
    pos++;

  if (pos == mtddata->region_count)
    return FALSE;

  *index = pos;
  return TRUE;
}

/**
 * gst_analytics_segmentation_mtd_get_region_id:
 * @handle: Instance
 * @index: Region index
 *
 * Get the id of the region stored at position @index. @index must be
 * smaller than the value returned by
 * gst_analytics_segmentation_mtd_get_region_count(); otherwise the
 * precondition check fails and 0 is returned.
 *
 * Returns: The region ID
 *
 * Since: 1.26
 */
guint
gst_analytics_segmentation_mtd_get_region_id (const
    GstAnalyticsSegmentationMtd * handle, gsize index)
{
  GstAnalyticsSegMtdData *mtddata;
  const guint32 *ids;

  g_return_val_if_fail (handle, 0);

  mtddata = gst_analytics_relation_meta_get_mtd_data (handle->meta, handle->id);
  g_return_val_if_fail (mtddata != NULL, 0);
  g_return_val_if_fail (index < mtddata->region_count, 0);

  ids = mtddata->region_ids;
  return ids[index];
}

/**
 * gst_analytics_segmentation_mtd_get_region_count:
 * @handle: Instance
 *
 * Get the regions count.
 *
 * Returns: Number of regions segmented, or 0 if @handle is %NULL or its data
 * can't be retrieved.
 *
 * Since: 1.26
 */
gsize
gst_analytics_segmentation_mtd_get_region_count (const
    GstAnalyticsSegmentationMtd * handle)
{
  GstAnalyticsSegMtdData *mtddata;

  /* The return type is a count, so the failure value is 0 rather than the
   * boolean FALSE */
  g_return_val_if_fail (handle, 0);

  mtddata = gst_analytics_relation_meta_get_mtd_data (handle->meta, handle->id);
  g_return_val_if_fail (mtddata != NULL, 0);

  return mtddata->region_count;
}

/**
 * gst_analytics_relation_meta_add_segmentation_mtd:
 * @instance: Instance of #GstAnalyticsRelationMeta where to add segmentation
 * instance.
 * @buffer:(in)(transfer full): Buffer containing segmentation masks. @buffer
 * must have a #GstVideoMeta attached
 * @segmentation_type:(in): Segmentation type
 * @region_count:(in): Number of regions in the masks
 * @region_ids:(in) (array length=region_count): Arrays of region ids present in the mask.
 * @masks_loc_x:(in): Left coordinate of the rectangle corresponding to the masks in the image.
 * @masks_loc_y:(in): Top coordinate of the rectangle corresponding to the masks in the image.
 * @masks_loc_w:(in): Width of the rectangle corresponding to the masks in the image.
 * @masks_loc_h:(in): Height of the rectangle corresponding to the masks in the image.
 * @segmentation_mtd:(out)(not nullable): Handle update with newly added segmentation meta.
 *
 * Add analytics segmentation metadata to @instance. The rectangle (@masks_loc_x,
 * @masks_loc_y, @masks_loc_w, @masks_loc_h) define a area of the image that
 * correspond to the segmentation masks stored in @buffer. For example if the
 * segmentation masks stored in @buffer describe the segmented regions for the
 * entire image the rectangular area will be (@masks_loc_x = 0, @masks_loc_y = 0,
 * @masks_loc_w = image_width, @masks_loc_h = image_height).
 *
 * The #GstVideoMeta of @buffer must use one of the GRAY8, GRAY16_BE or
 * GRAY16_LE formats. @region_ids may only be %NULL when @region_count is 0.
 * The region ids are copied, so the caller keeps ownership of @region_ids.
 *
 * Returns: TRUE if added successfully, otherwise FALSE
 *
 * Since: 1.26
 */
gboolean
gst_analytics_relation_meta_add_segmentation_mtd (GstAnalyticsRelationMeta *
    instance, GstBuffer * buffer, GstSegmentationType segmentation_type,
    gsize region_count, guint * region_ids, gint masks_loc_x, gint masks_loc_y,
    guint masks_loc_w, guint masks_loc_h, GstAnalyticsSegmentationMtd *
    segmentation_mtd)
{
  GstAnalyticsSegMtdData *mtddata = NULL;
  gsize size;
  gsize i;

  g_return_val_if_fail (instance != NULL, FALSE);
  g_return_val_if_fail (GST_IS_BUFFER (buffer), FALSE);
  g_return_val_if_fail (region_count == 0 || region_ids != NULL, FALSE);
  /* Make sure the size of the variable-length payload can't wrap around */
  g_return_val_if_fail (region_count <=
      (G_MAXSIZE - sizeof (GstAnalyticsSegMtdData)) /
      sizeof (mtddata->region_ids[0]), FALSE);
#ifndef G_DISABLE_CHECKS
  GstVideoMeta *vmeta = gst_buffer_get_video_meta (buffer);
  g_return_val_if_fail (vmeta != NULL, FALSE);
  g_return_val_if_fail (vmeta->format == GST_VIDEO_FORMAT_GRAY8 ||
      vmeta->format == GST_VIDEO_FORMAT_GRAY16_BE ||
      vmeta->format == GST_VIDEO_FORMAT_GRAY16_LE, FALSE);
#endif

  /* Size the trailing array from the element type actually stored in the
   * structure so allocation and storage can never disagree. */
  size = sizeof (GstAnalyticsSegMtdData) +
      region_count * sizeof (mtddata->region_ids[0]);

  mtddata =
      (GstAnalyticsSegMtdData *) gst_analytics_relation_meta_add_mtd (instance,
      &segmentation_impl, size, segmentation_mtd);

  /* NOTE(review): when the mtd can't be added, @buffer is neither stored nor
   * unreffed here even though it is annotated (transfer full); confirm which
   * side is expected to release it in that case. */
  if (mtddata == NULL)
    return FALSE;

  mtddata->masks = buffer;
  mtddata->type = segmentation_type;
  mtddata->region_count = region_count;
  mtddata->masks_loc_x = masks_loc_x;
  mtddata->masks_loc_y = masks_loc_y;
  mtddata->masks_loc_w = masks_loc_w;
  mtddata->masks_loc_h = masks_loc_h;

  /* Element-wise copy: type-correct between the caller's guint array and the
   * stored guint32 array, and never touches @region_ids when the count is 0
   * (where it is allowed to be NULL). */
  for (i = 0; i < region_count; i++)
    mtddata->region_ids[i] = region_ids[i];

  return TRUE;
}

/* Clear callback registered in segmentation_impl: drops the reference this
 * mtd holds on its mask buffer. @buffer is not used here. */
static void
gst_analytics_segmentation_mtd_clear (GstBuffer * buffer, GstAnalyticsMtd * mtd)
{
  GstAnalyticsSegMtdData *segdata =
      gst_analytics_relation_meta_get_mtd_data (mtd->meta, mtd->id);

  g_return_if_fail (segdata != NULL);

  gst_clear_buffer (&segdata->masks);
}

/*
 * gst_analytics_segmentation_mtd_transform:
 * @transbuf: Buffer the metadata is being transformed to
 * @transmtd: Handle of the segmentation mtd to update
 * @buffer: Buffer the metadata is coming from
 * @type: Quark identifying the kind of transformation
 * @data: Transformation data; interpreted as #GstVideoMetaTransformMatrix for
 * matrix transforms and as #GstVideoMetaTransform for scale transforms
 *
 * Transform callback registered in segmentation_impl. It takes an extra
 * reference on the mask buffer when the destination buffer differs from the
 * source buffer, then adapts the masks location rectangle to the matrix or
 * scale transformation. Other transformation types leave the rectangle
 * untouched.
 *
 * Returns: TRUE on success, FALSE if the data of @transmtd can't be
 * retrieved, if the matrix contains a rotation or a flip, if the rectangle
 * can't be transformed by the matrix, or if the scale dimensions are invalid.
 */
static gboolean
gst_analytics_segmentation_mtd_transform (GstBuffer * transbuf,
    GstAnalyticsMtd * transmtd, GstBuffer * buffer, GQuark type, gpointer data)
{
  GstAnalyticsSegMtdData *segdata =
      gst_analytics_relation_meta_get_mtd_data (transmtd->meta,
      transmtd->id);

  g_return_val_if_fail (segdata != NULL, FALSE);

  /* NOTE(review): presumably the mtd data was duplicated as-is onto
   * @transbuf, so both metas point at the same masks buffer and the new one
   * needs its own reference -- confirm against the relation meta transform. */
  if (transbuf != buffer)
    gst_buffer_ref (segdata->masks);

  if (GST_VIDEO_META_TRANSFORM_IS_MATRIX (type)) {
    GstVideoMetaTransformMatrix *trans = data;
    GstVideoRectangle rect = { segdata->masks_loc_x, segdata->masks_loc_y,
      segdata->masks_loc_w, segdata->masks_loc_h
    };

    if (trans->matrix[0][1] != 0 || trans->matrix[1][0] != 0 ||
        trans->matrix[0][0] < 0 || trans->matrix[1][1] < 0) {
      GST_WARNING ("Segmentation meta doesn't support rotations or flips,"
          " not copying from buffer %" GST_PTR_FORMAT " to buffer: %"
          GST_PTR_FORMAT, buffer, transbuf);
      return FALSE;
    }

    if (!gst_video_meta_transform_matrix_rectangle (trans, &rect))
      return FALSE;

    segdata->masks_loc_x = rect.x;
    segdata->masks_loc_y = rect.y;
    segdata->masks_loc_w = rect.w;
    segdata->masks_loc_h = rect.h;
  } else if (GST_VIDEO_META_TRANSFORM_IS_SCALE (type)) {
    GstVideoMetaTransform *trans = data;
    gint ow, oh, nw, nh;

    ow = GST_VIDEO_INFO_WIDTH (trans->in_info);
    nw = GST_VIDEO_INFO_WIDTH (trans->out_info);
    oh = GST_VIDEO_INFO_HEIGHT (trans->in_info);
    nh = GST_VIDEO_INFO_HEIGHT (trans->out_info);

    /* The input dimensions are used as divisors and the output dimensions
     * are mixed with unsigned values: reject anything that would divide by
     * zero or scale by a negative factor. */
    if (ow <= 0 || oh <= 0 || nw < 0 || nh < 0) {
      GST_WARNING ("Segmentation meta can't be scaled from %dx%d to %dx%d",
          ow, oh, nw, nh);
      return FALSE;
    }

    /* Multiply in 64 bits so the intermediate product can't overflow; the
     * division truncates toward zero exactly like the 32-bit form did. */
    segdata->masks_loc_x =
        (gint) (((gint64) segdata->masks_loc_x * nw) / ow);
    segdata->masks_loc_w =
        (guint) (((guint64) segdata->masks_loc_w * (guint64) nw) /
        (guint64) ow);

    segdata->masks_loc_y =
        (gint) (((gint64) segdata->masks_loc_y * nh) / oh);
    segdata->masks_loc_h =
        (guint) (((guint64) segdata->masks_loc_h * (guint64) nh) /
        (guint64) oh);
  }

  return TRUE;
}

/**
 * gst_analytics_relation_meta_get_segmentation_mtd:
 * @meta: Instance of #GstAnalyticsRelationMeta
 * @an_meta_id: Id of #GstAnalyticsSegmentationMtd instance to retrieve
 * @rlt: (out caller-allocates)(not nullable): Will be filled with relatable
 *    meta
 *
 * Look up, in @meta, the analytics-meta whose id is @an_meta_id and whose
 * type is the segmentation type, and fill @rlt with it. When this method
 * returns FALSE, @rlt is invalid.
 *
 * Returns: TRUE if successful.
 *
 * Since: 1.26
 */
gboolean
gst_analytics_relation_meta_get_segmentation_mtd (GstAnalyticsRelationMeta *
    meta, guint an_meta_id, GstAnalyticsSegmentationMtd * rlt)
{
  const GstAnalyticsMtdType seg_type =
      gst_analytics_segmentation_mtd_get_mtd_type ();

  return gst_analytics_relation_meta_get_mtd (meta, an_meta_id, seg_type, rlt);
}
