/*
 * GStreamer gstreamer-classifiertensordecoder
 * Copyright (C) 2025 Collabora Ltd.
 *  @author: Daniel Morin <daniel.morin@dmohub.org>
 *
 * gstclassifiertensordecoder.c
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA 02110-1301, USA.
 */

/**
 * SECTION:element-classifiertensordecoder
 * @short_description: Decode tensors from classification model using a common
 * tensor output format.
 *
 *
 * This element can parse per-buffer inference tensor meta data generated by
 * an upstream inference element.
 *
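 * Tensor format must be:
 *   Dims: [batch-size, class_count] or [class_count]
 *   Datatype: float32 or uint8 (uint8 scores are scaled by 1/255)
 *
 *   Tensor [M,N]
 *   Batch 0   | Class 0 confidence level | ... | Class N-1 confidence level |
 *   ...
 *   Batch M-1 | Class 0 confidence level | ... | Class N-1 confidence level |
 *
 *   In-memory tensor format:
 *
 *   |Batch 0, Class 0 confidence level     |
 *   |Batch 0,           ...                |
 *   |Batch 0, Class N-1 confidence level   |
 *   |               ...                    |
 *   |Batch M-1, Class 0 confidence level   |
 *   |Batch M-1,           ...              |
 *   |Batch M-1, Class N-1 confidence level |
 *
 * Tensors with tensor-id "classification-generic-out" hold raw scores and are
 * softmax-normalised by this element; tensors with tensor-id
 * "classification-generic-softmaxed-out" are used as-is.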
 *
 *
 * ## Example launch command:
 * |[
 * gst-launch-1.0 filesrc location=/onnx-models/images/bus.jpg                 \
 *  ! jpegdec                                                                  \
 *  ! videoconvertscale add-borders=1                                          \
 *  ! onnxinference execution-provider=cpu                                     \
 *    model-file=/onnx-models/models/mobilenet_v1.onnx                         \
 *  ! classifiertensordecoder labels-file=labels.txt ! fakesink
 * ]| This pipeline creates a tensor decoder for a classification model.
 *
 */

#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#include "gstclassifiertensordecoder.h"
#include <gst/gst.h>
#include <math.h>
#include <gst/analytics/analytics.h>

/* Names of the tensor groups looked up in the "tensors" field of the sink
 * caps: raw (not yet softmaxed) scores and already-softmaxed scores. */
#define GROUP_ID_CLASSIFICATION "classification-generic-out"
#define GROUP_ID_CLASSIFICATION_SOFTMAXED "classification-generic-softmaxed-out"
/* tensor-id values matched in caps and in the buffer's tensor meta. They
 * currently hold the same strings as the group names above. The raw variant
 * makes set_caps() enable softmax (do_softmax = TRUE); the softmaxed variant
 * disables it. */
#define GST_MODEL_STD_IMAGE_CLASSIFICATION "classification-generic-out"
#define GST_MODEL_STD_IMAGE_CLASSIFICATION_SOFTMAXED "classification-generic-softmaxed-out"

GST_DEBUG_CATEGORY_STATIC (classifier_tensor_decoder_debug);
#define GST_CAT_DEFAULT classifier_tensor_decoder_debug
/* Renames the parent-class pointer declared by G_DEFINE_TYPE to plain
 * `parent_class`; both spellings then refer to the same variable. */
#define gst_classifier_tensor_decoder_parent_class parent_class

/* Registers the element as "classifiertensordecoder" with secondary rank. */
GST_ELEMENT_REGISTER_DEFINE (classifier_tensor_decoder,
    "classifiertensordecoder", GST_RANK_SECONDARY,
    GST_TYPE_CLASSIFIER_TENSOR_DECODER);


/* GstClassifierTensorDecoder properties.
 * PROP_0 is a placeholder so that real property ids start at 1. */
enum
{
  PROP_0,
  PROP_THRESHOLD,
  PROP_LABEL_FILE
};

/* Default value of the "class-confidence-threshold" property. */
static const float DEFAULT_THRESHOLD = 0.7f;

/* Source pad: any caps. Buffers are modified in place (transform_ip) and
 * pushed downstream. */
static GstStaticPadTemplate gst_classifier_tensor_decoder_src_template =
GST_STATIC_PAD_TEMPLATE ("src",
    GST_PAD_SRC,
    GST_PAD_ALWAYS,
    GST_STATIC_CAPS_ANY);

/* *INDENT-OFF* */

/* Sink pad: raw video whose caps describe, in the "tensors" field, one of the
 * two classification tensor groups:
 *   - GROUP_ID_CLASSIFICATION: raw scores, softmax is applied by this element;
 *   - GROUP_ID_CLASSIFICATION_SOFTMAXED: scores already softmax-normalised.
 * Each group accepts a row-major float32 or uint8 tensor whose dims are either
 * [batch (0 or 1), class_count] or [class_count]. */
static GstStaticPadTemplate gst_classifier_tensor_decoder_sink_template =
GST_STATIC_PAD_TEMPLATE ("sink",
    GST_PAD_SINK,
    GST_PAD_ALWAYS,
    GST_STATIC_CAPS (
      "video/x-raw,"
        "tensors=(structure)["
          "tensorgroups,"
              GROUP_ID_CLASSIFICATION"=(/uniquelist){"
                 "(GstCaps)["
                    "tensor/strided,"
                      "tensor-id="GST_MODEL_STD_IMAGE_CLASSIFICATION","
                      "dims=<(int)[0,1], (int)[1,max]>,"
                      "dims-order=(string)row-major,"
                      "type={float32, uint8};"
                    "tensor/strided,"
                      "tensor-id="GST_MODEL_STD_IMAGE_CLASSIFICATION","
                      "dims=<(int)[1,max]>,"
                      "dims-order=(string)row-major,"
                      "type={float32, uint8};]"
              "}"
        "];"
      "video/x-raw,"
        "tensors=(structure)["
          "tensorgroups,"
              GROUP_ID_CLASSIFICATION_SOFTMAXED"=(/uniquelist){"
                 "(GstCaps)["
                    "tensor/strided,"
                      "tensor-id="GST_MODEL_STD_IMAGE_CLASSIFICATION_SOFTMAXED","
                      "dims=<(int)[0,1], (int)[1,max]>,"
                      "dims-order=(string)row-major,"
                      "type={float32, uint8};"
                    "tensor/strided,"
                      "tensor-id="GST_MODEL_STD_IMAGE_CLASSIFICATION_SOFTMAXED","
                      "dims=<(int)[1,max]>,"
                      "dims-order=(string)row-major,"
                      "type={float32, uint8};]"
              "}"
        "]"
    ));
/* *INDENT-ON* */

/* GObject vmethods */
static void gst_classifier_tensor_decoder_set_property (GObject * obj,
    guint id, const GValue * val, GParamSpec * spec);
static void gst_classifier_tensor_decoder_get_property (GObject * obj,
    guint id, GValue * val, GParamSpec * spec);
static void gst_classifier_tensor_decoder_finalize (GObject * obj);

/* GstBaseTransform vmethods */
static GstFlowReturn gst_classifier_tensor_decoder_transform_ip
    (GstBaseTransform * base, GstBuffer * buffer);

/* GstElement vmethods */
static GstStateChangeReturn gst_classifier_tensor_decoder_change_state
    (GstElement * elem, GstStateChange change);

/* GstBaseTransform vmethods */
static gboolean gst_classifier_tensor_decoder_set_caps (GstBaseTransform *
    base, GstCaps * in_caps, GstCaps * out_caps);


/* Numerically stable in-place softmax over @len floats.
 *
 * Softmax is invariant to adding a constant to every input, so the maximum is
 * subtracted before exponentiation. This keeps expf() from overflowing to
 * +inf (which would turn the normalised output into NaN) when the model emits
 * large logits. A zero-length input is a no-op. */
static void
softmax_inplace (gsize len, gfloat * values)
{
  gfloat max_val;
  gfloat sum = 0.0f;
  gsize i;

  if (len == 0)
    return;

  max_val = values[0];
  for (i = 1; i < len; i++) {
    if (values[i] > max_val)
      max_val = values[i];
  }

  for (i = 0; i < len; i++) {
    values[i] = expf (values[i] - max_val);
    sum += values[i];
  }

  /* sum >= 1 here: the maximum element contributes expf (0) == 1. */
  for (i = 0; i < len; i++)
    values[i] /= sum;
}

/* Softmax of @len uint8 scores into @result (@len floats).
 * Scores are first scaled from [0, 255] to [0, 1], as the previous
 * implementation did. */
static void
softmax_u8 (gsize len, const guint8 * values, gfloat * result)
{
  gsize i;

  g_return_if_fail (values != NULL);
  g_return_if_fail (result != NULL);

  for (i = 0; i < len; i++)
    result[i] = values[i] / 255.0f;

  softmax_inplace (len, result);
}

/* Softmax of @len float scores into @result (@len floats).
 * @values is left untouched. */
static void
softmax_f32 (gsize len, const gfloat * values, gfloat * result)
{
  gsize i;

  g_return_if_fail (values != NULL);
  g_return_if_fail (result != NULL);

  for (i = 0; i < len; i++)
    result[i] = values[i];

  softmax_inplace (len, result);
}

G_DEFINE_TYPE (GstClassifierTensorDecoder, gst_classifier_tensor_decoder,
    GST_TYPE_BASE_TRANSFORM);

/* Class initialiser: sets up the debug category, installs the GObject,
 * GstElement and GstBaseTransform vmethods, declares the
 * "class-confidence-threshold" and "labels-file" properties, and registers
 * the element metadata and pad templates. */
static void
gst_classifier_tensor_decoder_class_init (GstClassifierTensorDecoderClass *
    klass)
{
  GObjectClass *object_class = G_OBJECT_CLASS (klass);
  GstElementClass *elem_class = GST_ELEMENT_CLASS (klass);
  GstBaseTransformClass *trans_class = GST_BASE_TRANSFORM_CLASS (klass);
  const GParamFlags rw_flags =
      (GParamFlags) (G_PARAM_READWRITE | G_PARAM_STATIC_STRINGS);

  GST_DEBUG_CATEGORY_INIT (classifier_tensor_decoder_debug,
      "classifiertensordecoder", 0,
      "Tensor decoder for classification model with common output format");

  /* GObject */
  object_class->set_property = gst_classifier_tensor_decoder_set_property;
  object_class->get_property = gst_classifier_tensor_decoder_get_property;
  object_class->finalize = gst_classifier_tensor_decoder_finalize;

  g_object_class_install_property (object_class, PROP_THRESHOLD,
      g_param_spec_float ("class-confidence-threshold",
          "Class confidence threshold",
          "Classes with a confidence level inferior to this threshold "
          "will be excluded", 0.0, 1.0, DEFAULT_THRESHOLD, rw_flags));

  g_object_class_install_property (object_class, PROP_LABEL_FILE,
      g_param_spec_string ("labels-file", "Class labels file",
          "Path to a file containing class label. COCO format", NULL,
          rw_flags));

  /* GstElement */
  elem_class->change_state = gst_classifier_tensor_decoder_change_state;

  gst_element_class_set_static_metadata (elem_class,
      "Classification tensor decoder", "Tensordecoder",
      "Decode tensors output from classification model using common format.\n"
      "\tTensor format must be: \n" "\t\tDims: [batch-size, class_count]\n"
      "\t\tDatatype: float32 \n" "\n" "\t\tTensor [M,N]\n"
      "\t\t\tBatch 0   | Class 0 confidence level | ... | Class N-1 confidence level |\n"
      "\t\t\t...\n"
      "\t\t\tBatch M-1 | Class 0 confidence level | ... | Class N-1 confidence level |\n"
      "\t\t\n" "\tIn-memory tensor format:\n" "\n"
      "\t\t|Batch 0, Class 0 confidence level     |\n"
      "\t\t|Batch 0,           ...                |\n"
      "\t\t|Batch 0, Class N-1 confidence level   |\n"
      "\t\t|               ...                    |\n"
      "\t\t|Batch M-1, Class 0 confidence level   |\n"
      "\t\t|Batch M-1,           ...              |\n"
      "\t\t|Batch M-1, Class N-1 confidence level |\n" "\n" " model",
      "Daniel Morin <daniel.morin@collabora.com>");

  gst_element_class_add_static_pad_template (elem_class,
      &gst_classifier_tensor_decoder_sink_template);
  gst_element_class_add_static_pad_template (elem_class,
      &gst_classifier_tensor_decoder_src_template);

  /* GstBaseTransform */
  trans_class->transform_ip =
      GST_DEBUG_FUNCPTR (gst_classifier_tensor_decoder_transform_ip);
  trans_class->set_caps =
      GST_DEBUG_FUNCPTR (gst_classifier_tensor_decoder_set_caps);
}

static void
gst_classifier_tensor_decoder_init (GstClassifierTensorDecoder * self)
{
  /* Property defaults and not-yet-negotiated state. */
  self->labels_file = NULL;
  self->threshold = DEFAULT_THRESHOLD;
  self->class_count = 0;
  self->postproc_result = NULL;
  self->do_softmax = TRUE;

  /* transform_ip() attaches analytics meta to buffers, so passthrough is
   * disabled. Clearing ACCEPT_INTERSECT makes the default accept-caps handler
   * check for a caps subset instead of an intersection. */
  gst_base_transform_set_passthrough (GST_BASE_TRANSFORM (self), FALSE);
  GST_PAD_UNSET_ACCEPT_INTERSECT (GST_BASE_TRANSFORM_SINK_PAD (self));
}

static void
gst_classifier_tensor_decoder_finalize (GObject * object)
{
  GstClassifierTensorDecoder *decoder = GST_CLASSIFIER_TENSOR_DECODER (object);
  GObjectClass *parent_object_class = G_OBJECT_CLASS (parent_class);

  /* Only the labels-file path is owned here; the label and scratch arrays
   * are released in change_state() on READY_TO_NULL. */
  g_free (decoder->labels_file);

  parent_object_class->finalize (object);
}

/* GObject set_property vmethod.
 *
 * "class-confidence-threshold" stores the value as-is. For "labels-file" the
 * previously stored path is freed and replaced by a copy of the new one. A
 * path that does not exist or is not a regular file is logged and discarded,
 * leaving the property NULL. A NULL value is logged as an error. */
static void
gst_classifier_tensor_decoder_set_property (GObject * object, guint prop_id,
    const GValue * value, GParamSpec * pspec)
{
  GstClassifierTensorDecoder *self = GST_CLASSIFIER_TENSOR_DECODER (object);
  const GFileTest filetest = (G_FILE_TEST_EXISTS | G_FILE_TEST_IS_REGULAR);

  switch (prop_id) {
    case PROP_THRESHOLD:
      self->threshold = g_value_get_float (value);
      break;
    case PROP_LABEL_FILE:
      /* Release the previous path; overwriting it directly leaked it. */
      g_free (self->labels_file);
      self->labels_file = g_strdup (g_value_get_string (value));

      if (self->labels_file) {
        if (!g_file_test (self->labels_file, filetest)) {
          GST_ERROR_OBJECT (self, "Unable to load %s", self->labels_file);
          g_free (g_steal_pointer (&self->labels_file));
        }
      } else {
        GST_ERROR_OBJECT (self, "Invalid file");
      }
      break;
    default:
      G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
      break;
  }
}

/* GObject get_property vmethod: reports the current threshold and the
 * labels-file path (NULL when unset or rejected). */
static void
gst_classifier_tensor_decoder_get_property (GObject * object, guint prop_id,
    GValue * value, GParamSpec * pspec)
{
  GstClassifierTensorDecoder *decoder = GST_CLASSIFIER_TENSOR_DECODER (object);

  if (prop_id == PROP_THRESHOLD)
    g_value_set_float (value, decoder->threshold);
  else if (prop_id == PROP_LABEL_FILE)
    g_value_set_string (value, decoder->labels_file);
  else
    G_OBJECT_WARN_INVALID_PROPERTY_ID (object, prop_id, pspec);
}

/* Load class labels from the "labels-file" property, one label per line.
 *
 * Trailing whitespace is stripped from each line, which also removes the
 * '\r' left by CRLF line endings. Reading stops at the first empty line.
 * Line i becomes the label (as a GQuark) of class i.
 *
 * On success the labels are stored in self->class_quark and their count is
 * returned. On any failure (no path, unreadable file, empty file, no labels)
 * 0 is returned and self->class_quark is not modified.
 */
static guint
gst_classifier_tensor_decoder_load_labels (GstClassifierTensorDecoder * self)
{
  gchar *content = NULL;
  gchar **tokens = NULL;
  gsize len;
  GError *err = NULL;
  GQuark val;
  GArray *class_quark;

  if (self->labels_file == NULL) {
    GST_ERROR_OBJECT (self, "Missing label file");
    return 0;
  }
  if (!g_file_get_contents (self->labels_file, &content, &len, &err)) {
    GST_ERROR_OBJECT (self, "Could not load labels file %s: %s",
        self->labels_file, err->message);
    g_error_free (err);
    return 0;
  }

  if (len == 0) {
    GST_ERROR_OBJECT (self, "Labels file %s is empty", self->labels_file);
    g_free (content);
    return 0;
  }

  tokens = g_strsplit (content, "\n", 0);
  g_free (content);

  /* Build into a local array so nothing is stored when loading fails. */
  class_quark =
      g_array_sized_new (FALSE, FALSE, sizeof (GQuark), self->class_count);

  for (gsize i = 0; tokens[i] != NULL; i++) {
    g_strchomp (tokens[i]);

    /* An empty line ends the list (normally the file's final newline). */
    if (tokens[i][0] == '\0')
      break;

    val = g_quark_from_string (tokens[i]);
    g_array_append_val (class_quark, val);
  }

  g_strfreev (tokens);

  if (class_quark->len == 0) {
    GST_WARNING_OBJECT (self, "Labels file %s does not contain any labels",
        self->labels_file);
    g_array_free (class_quark, TRUE);
    return 0;
  }

  /* NOTE(review): an array already stored in self->class_quark is
   * overwritten, not freed. Freeing it here is only safe if every other
   * release site resets the pointer to NULL after freeing. */
  self->class_quark = class_quark;

  return class_quark->len;
}

/* GstElement change_state vmethod.
 *
 * NULL_TO_READY: load the labels file when one is configured; fail the
 * transition if no label could be loaded.
 * READY_TO_NULL: release the label table and the post-processing buffer and
 * reset the pointers, so a later restart starts clean. Resetting avoids a
 * use-after-free in set_caps() and a double free on the next READY_TO_NULL.
 */
static GstStateChangeReturn
gst_classifier_tensor_decoder_change_state (GstElement * element,
    GstStateChange transition)
{
  GstClassifierTensorDecoder *self = GST_CLASSIFIER_TENSOR_DECODER (element);
  GstStateChangeReturn ret;

  switch (transition) {
    case GST_STATE_CHANGE_NULL_TO_READY:
      if (self->labels_file != NULL &&
          !gst_classifier_tensor_decoder_load_labels (self)) {
        return GST_STATE_CHANGE_FAILURE;
      }
      break;
    default:
      break;
  }

  ret = GST_ELEMENT_CLASS (parent_class)->change_state (element, transition);

  switch (transition) {
    case GST_STATE_CHANGE_READY_TO_NULL:
      /* free_segment = TRUE: the element storage must be freed too, otherwise
       * g_array_free() hands it back to the caller and it leaks. */
      if (self->class_quark) {
        g_array_free (self->class_quark, TRUE);
        self->class_quark = NULL;
      }
      if (self->postproc_result) {
        g_array_free (self->postproc_result, TRUE);
        self->postproc_result = NULL;
      }
      self->class_count = 0;
      break;
    default:
      break;
  }

  return ret;
}

/* Find the tensor identified by @tensor_id in @tmeta, trying in order:
 * float32 1-D (any shape), float32 [1, N], uint8 1-D (any shape),
 * uint8 [1, N]; all row-major. Returns NULL when none matches. */
static const GstTensor *
get_tensor (GstTensorMeta * tmeta, GQuark tensor_id)
{
  const gsize batched_dims[] = { 1, G_MAXSIZE };
  const GstTensor *found = NULL;
  guint attempt;

  for (attempt = 0; attempt < 4 && found == NULL; attempt++) {
    gboolean want_u8 = attempt >= 2;
    gboolean want_batched = (attempt % 2) == 1;

    found = gst_tensor_meta_get_typed_tensor (tmeta, tensor_id,
        want_u8 ? GST_TENSOR_DATA_TYPE_UINT8 : GST_TENSOR_DATA_TYPE_FLOAT32,
        GST_TENSOR_DIM_ORDER_ROW_MAJOR,
        want_batched ? 2 : 1, want_batched ? batched_dims : NULL);
  }

  return found;
}

/* Return the first classification tensor found in @buf's tensor metas.
 * The tensor-id searched for follows the negotiated variant: raw scores when
 * softmax still has to be applied, pre-softmaxed scores otherwise.
 * Returns NULL when the buffer has no tensor meta or no matching tensor. */
static const GstTensor *
gst_classifier_tensor_decoder_get_tensor (GstClassifierTensorDecoder *
    self, GstBuffer * buf)
{
  gpointer state = NULL;
  GQuark wanted_id;
  GstMeta *m;

  if (gst_buffer_get_meta (buf, GST_TENSOR_META_API_TYPE) == NULL) {
    GST_DEBUG_OBJECT (self,
        "missing tensor meta from buffer %" GST_PTR_FORMAT, buf);
    return NULL;
  }

  if (self->do_softmax)
    wanted_id = g_quark_from_static_string (GST_MODEL_STD_IMAGE_CLASSIFICATION);
  else
    wanted_id =
        g_quark_from_static_string
        (GST_MODEL_STD_IMAGE_CLASSIFICATION_SOFTMAXED);

  for (m = gst_buffer_iterate_meta_filtered (buf, &state,
          GST_TENSOR_META_API_TYPE); m != NULL;
      m = gst_buffer_iterate_meta_filtered (buf, &state,
          GST_TENSOR_META_API_TYPE)) {
    const GstTensor *t = get_tensor ((GstTensorMeta *) m, wanted_id);

    if (t != NULL)
      return t;
  }

  return NULL;
}

/* Decode one classification tensor and attach the best class to @rmeta.
 *
 * The last tensor dimension is taken as the number of class scores; when it
 * differs from the number of labels only the common prefix is used.
 * - Scores are softmax-normalised when do_softmax is set.
 * - uint8 scores that need no softmax are scaled by 1/255.
 * - The class with the highest score is added as a classification mtd,
 *   provided its confidence is not below the "class-confidence-threshold"
 *   property.
 *
 * Returns GST_FLOW_OK on success, including when no class is reported.
 * Returns GST_FLOW_ERROR, after posting an element error, when the decoder is
 * not configured or the tensor cannot be read.
 */
static GstFlowReturn
gst_classifier_tensor_decoder_decode (GstClassifierTensorDecoder * self,
    const GstTensor * tensor, GstAnalyticsRelationMeta * rmeta)
{
  GstMapInfo map_info = GST_MAP_INFO_INIT;
  gfloat max = 0.0;
  gfloat *result_data = NULL;
  gsize len;
  GQuark qmax = 0;
  gint max_idx = -1;
  GstAnalyticsClsMtd cls_mtd;

  /* Both arrays are created by set_caps(). Without them there are no labels
   * to map scores to and no scratch buffer to write into. */
  if (self->class_quark == NULL || self->postproc_result == NULL) {
    GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL),
        ("Decoder is not configured: no class labels available"));
    return GST_FLOW_ERROR;
  }

  if (tensor->num_dims == 0) {
    GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL),
        ("Tensor has no dimensions"));
    return GST_FLOW_ERROR;
  }

  len = tensor->dims[tensor->num_dims - 1];

  if (len != self->class_quark->len) {
    GST_WARNING_OBJECT (self, "Labels file has size %zu, but the tensor has"
        " %u entries, it is probably not the right labels file",
        len, self->class_quark->len);
    len = MIN (len, self->class_quark->len);
  }

  if (!gst_buffer_map (tensor->data, &map_info, GST_MAP_READ)) {
    GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL),
        ("Failed to map tensor data"));
    return GST_FLOW_ERROR;
  }

  GST_TRACE_OBJECT (self, "Tensor shape dims %zu", tensor->num_dims);

  if (gst_debug_category_get_threshold (GST_CAT_DEFAULT) >= GST_LEVEL_TRACE) {
    for (gsize i = 0; i < tensor->num_dims; i++) {
      GST_TRACE_OBJECT (self, "Tensor dim %zu: %zu", i, tensor->dims[i]);
    }
  }

  switch (tensor->data_type) {
    case GST_TENSOR_DATA_TYPE_FLOAT32:
      if (map_info.size != len * sizeof (gfloat)) {
        GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL),
            ("Tensor size is not as expected for float: map.size(%zu) !="
                " label-file-length(%zu) * sizeof(float)(%zu)", map_info.size,
                len, sizeof (float)));
        goto error_mapped;
      }

      if (self->do_softmax) {
        result_data = (gfloat *) self->postproc_result->data;
        softmax_f32 (len, (gfloat *) map_info.data, result_data);
      } else {
        /* Already softmaxed, use data directly */
        result_data = (gfloat *) map_info.data;
      }
      break;
    case GST_TENSOR_DATA_TYPE_UINT8:
      if (map_info.size != len) {
        GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL),
            ("Tensor size is not as expected for uint8: map.size(%zu) !="
                " label-file-length(%zu))", map_info.size, len));
        goto error_mapped;
      }

      /* Always need conversion buffer for uint8 -> float */
      result_data = (gfloat *) self->postproc_result->data;
      if (self->do_softmax) {
        softmax_u8 (len, (guint8 *) map_info.data, result_data);
      } else {
        const guint8 *uint8_data = map_info.data;
        for (gsize i = 0; i < len; i++) {
          result_data[i] = uint8_data[i] / 255.0;
        }
      }
      break;
    default:
      GST_ELEMENT_ERROR (self, STREAM, FAILED, (NULL),
          ("Can't handle data type %d", tensor->data_type));
      goto error_mapped;
  }

  /* result_data may point into the mapped buffer, so scan before unmapping. */
  for (gsize j = 0; j < len; j++) {
    if (result_data[j] > max) {
      max = result_data[j];
      max_idx = (gint) j;
    }
  }

  gst_buffer_unmap (tensor->data, &map_info);

  if (max_idx == -1) {
    GST_LOG_OBJECT (self, "No class with a positive confidence level");
    return GST_FLOW_OK;
  }

  qmax = g_array_index (self->class_quark, GQuark, max_idx);

  /* Honour the documented "class-confidence-threshold" property: classes
   * with a lower confidence are excluded. */
  if (max < self->threshold) {
    GST_LOG_OBJECT (self, "Max class %d:%s with %f is below threshold %f",
        max_idx, g_quark_to_string (qmax), max, self->threshold);
    return GST_FLOW_OK;
  }

  gst_analytics_relation_meta_add_one_cls_mtd (rmeta, max, qmax, &cls_mtd);
  GST_LOG_OBJECT (self, "Max class is %d:%s with %f", max_idx,
      g_quark_to_string (qmax), max);

  return GST_FLOW_OK;

error_mapped:
  gst_buffer_unmap (tensor->data, &map_info);
  return GST_FLOW_ERROR;
}

/* GstBaseTransform transform_ip vmethod.
 *
 * Buffers without a matching classification tensor are passed on unchanged
 * (with a warning). Otherwise the decoded class is appended to the buffer's
 * analytics relation meta, which is created only when the buffer has none.
 */
static GstFlowReturn
gst_classifier_tensor_decoder_transform_ip (GstBaseTransform * trans,
    GstBuffer * buf)
{
  GstClassifierTensorDecoder *self = GST_CLASSIFIER_TENSOR_DECODER (trans);
  const GstTensor *tensor;
  GstAnalyticsRelationMeta *rmeta;

  tensor = gst_classifier_tensor_decoder_get_tensor (self, buf);
  if (tensor == NULL) {
    GST_WARNING_OBJECT (trans, "missing tensor meta");
    return GST_FLOW_OK;
  }

  /* Reuse an existing relation meta (e.g. one added by an upstream element).
   * A second meta would likely be missed by readers that use
   * gst_buffer_get_analytics_relation_meta(), which returns a single meta. */
  rmeta = gst_buffer_get_analytics_relation_meta (buf);
  if (rmeta == NULL)
    rmeta = gst_buffer_add_analytics_relation_meta (buf);

  return gst_classifier_tensor_decoder_decode (self, tensor, rmeta);
}

/* GstBaseTransform set_caps vmethod: prepare decoding state from @incaps.
 *
 * Steps:
 * 1. Take the first tensor description of the classification tensor group
 *    (raw group first, then the softmaxed one).
 * 2. Choose from its tensor-id whether softmax must be applied.
 * 3. Read the class count from the last entry of its "dims" array (the same
 *    axis decode() uses).
 * 4. Allocate the float scratch buffer used by decode().
 * 5. When no labels file was loaded, generate labels "0" .. "class_count-1".
 *
 * Returns TRUE when the caps carry no tensor-id or an unknown tensor-id, in
 * which case no state is set up. Returns FALSE when the caps are malformed or
 * the class count does not match the loaded labels.
 */
static gboolean
gst_classifier_tensor_decoder_set_caps (GstBaseTransform * trans,
    GstCaps * incaps, GstCaps * outcaps)
{
  GstClassifierTensorDecoder *self = GST_CLASSIFIER_TENSOR_DECODER (trans);
  const GstCaps *tcaps;
  const GstStructure *s, *ts;
  const GValue *dims_v, *dim_v, *tensors_v, *tensors_gv, *tensor_caps_v;
  const gchar *tensor_id;
  gsize dims_size, batchsize = 1;
  gchar buffer[32];
  GQuark val;

  /* Get the classification tensor */
  s = gst_caps_get_structure (incaps, 0);
  g_return_val_if_fail (s != NULL, FALSE);

  tensors_v = gst_structure_get_value (s, "tensors");
  g_return_val_if_fail (tensors_v != NULL, FALSE);

  ts = gst_value_get_structure (tensors_v);
  g_return_val_if_fail (ts != NULL, FALSE);

  /* Try to get classification group (non-softmaxed) first */
  tensors_gv = gst_structure_get_value (ts, GROUP_ID_CLASSIFICATION);
  /* If not found, try softmaxed group */
  if (tensors_gv == NULL)
    tensors_gv =
        gst_structure_get_value (ts, GROUP_ID_CLASSIFICATION_SOFTMAXED);
  g_return_val_if_fail (tensors_gv != NULL, FALSE);

  tensor_caps_v = gst_value_unique_list_get_value (tensors_gv, 0);
  g_return_val_if_fail (tensor_caps_v != NULL, FALSE);

  /* Validate the caps pointer before taking a structure from it. */
  tcaps = gst_value_get_caps (tensor_caps_v);
  g_return_val_if_fail (tcaps != NULL, FALSE);
  s = gst_caps_get_structure (tcaps, 0);
  g_return_val_if_fail (s != NULL, FALSE);

  if (!gst_structure_has_field (s, "tensor-id"))
    return TRUE;

  tensor_id = gst_structure_get_string (s, "tensor-id");

  /* Determine if we need to apply softmax based on negotiated tensor-id */
  if (g_strcmp0 (tensor_id, GST_MODEL_STD_IMAGE_CLASSIFICATION) == 0) {
    self->do_softmax = TRUE;
  } else if (g_strcmp0 (tensor_id,
          GST_MODEL_STD_IMAGE_CLASSIFICATION_SOFTMAXED) == 0) {
    self->do_softmax = FALSE;
  } else {
    /* Unknown tensor-id, skip */
    return TRUE;
  }

  dims_v = gst_structure_get_value (s, "dims");
  if (dims_v == NULL || !GST_VALUE_HOLDS_ARRAY (dims_v) ||
      (dims_size = gst_value_array_get_size (dims_v)) == 0) {
    GST_ERROR_OBJECT (self, "Tensor caps have no usable dims: %"
        GST_PTR_FORMAT, tcaps);
    return FALSE;
  }

  if (dims_size == 2) {
    /* Explicit batch-size; 0 is treated as 1. */
    dim_v = gst_value_array_get_value (dims_v, 0);
    if (G_VALUE_HOLDS_INT (dim_v) && g_value_get_int (dim_v) > 0)
      batchsize = g_value_get_int (dim_v);
  }
  GST_DEBUG_OBJECT (self, "Batch size %" G_GSIZE_FORMAT, batchsize);

  /* Get classes count: the last dimension, as in decode(). */
  dim_v = gst_value_array_get_value (dims_v, dims_size - 1);
  if (!G_VALUE_HOLDS_INT (dim_v) || g_value_get_int (dim_v) <= 0) {
    GST_ERROR_OBJECT (self, "Tensor caps have no usable class count: %"
        GST_PTR_FORMAT, tcaps);
    return FALSE;
  }
  self->class_count = g_value_get_int (dim_v);

  /* Allocate postproc_result buffer for softmax or uint8->float conversion.
   * decode() writes into its reserved storage (class_count floats) directly.
   * NOTE(review): an array left by a previous negotiation is overwritten,
   * not freed. Freeing it here is only safe if every other release site
   * resets the pointer to NULL after freeing. */
  self->postproc_result =
      g_array_sized_new (FALSE, TRUE, sizeof (gfloat), self->class_count);

  if (self->class_quark != NULL &&
      self->class_count != self->class_quark->len) {
    GST_ELEMENT_ERROR (GST_BASE_TRANSFORM (self), STREAM, FAILED,
        ("Label-file/Tensor mismatch"),
        ("Class count from tensor mismatch class count from label file"));
    return FALSE;
  }

  /* Generate labels if no label file was specified. Build into a local array
   * so nothing partial is stored on failure. */
  if (self->class_quark == NULL) {
    GArray *labels = g_array_sized_new (FALSE, FALSE, sizeof (GQuark),
        self->class_count);

    for (gsize i = 0; i < (gsize) self->class_count; i++) {
      if (g_snprintf (buffer, sizeof (buffer), "%zu", i) >= sizeof (buffer)) {
        g_array_free (labels, TRUE);
        /* Just allocated above, so it is safe to release here. */
        g_array_free (self->postproc_result, TRUE);
        self->postproc_result = NULL;
        return FALSE;
      }
      val = g_quark_from_string (buffer);
      g_array_append_val (labels, val);
    }
    self->class_quark = labels;
  }

  return TRUE;
}
