package cn.quantgroup.boot.micrometer.register.kafka;

import static java.util.stream.Collectors.joining;

import io.micrometer.core.instrument.Clock;
import io.micrometer.core.instrument.DistributionSummary;
import io.micrometer.core.instrument.FunctionTimer;
import io.micrometer.core.instrument.LongTaskTimer;
import io.micrometer.core.instrument.Measurement;
import io.micrometer.core.instrument.Meter;
import io.micrometer.core.instrument.Timer;
import io.micrometer.core.instrument.step.StepMeterRegistry;
import io.micrometer.core.instrument.util.DoubleFormat;
import io.micrometer.core.instrument.util.MeterPartition;
import io.micrometer.core.instrument.util.NamedThreadFactory;
import io.micrometer.core.instrument.util.StringUtils;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Properties;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;
import java.util.stream.Stream;
import org.apache.kafka.clients.producer.KafkaProducer;
import org.apache.kafka.clients.producer.ProducerConfig;
import org.apache.kafka.clients.producer.ProducerRecord;
import org.springframework.util.ObjectUtils;

/**
 * Step-based meter registry that renders every registered meter as one InfluxDB line-protocol
 * string and sends each string as a record to a Kafka topic.
 *
 * <p>Records are sent to {@code config.topic()} on the brokers listed in
 * {@code config.services()}. The record key is the value of the {@code NAMESPACE} system
 * property, or {@code "default"} when that property is unset or empty.
 *
 * <p>The registry starts publishing as soon as it is constructed and owns its Kafka producer;
 * call {@link #close()} to release it.
 */
public class KafkaMeterRegistry extends StepMeterRegistry {

  private static final ThreadFactory DEFAULT_THREAD_FACTORY = new NamedThreadFactory(
      "kafka-metrics-publisher");

  /** Matches any character followed by an upper-case letter, i.e. a camelCase word boundary. */
  private static final Pattern CAMEL_CASE_BOUNDARY = Pattern.compile("(.)(\\p{Upper})");

  private final KafkaConfig config;
  private final KafkaProducer<String, String> kafkaProducer;

  /** Kafka record key attached to every published line. */
  private final String key;

  /**
   * Creates the registry, opens a Kafka producer with string key/value serializers and starts the
   * publishing schedule on a thread named {@code kafka-metrics-publisher}.
   *
   * @param config registry configuration; supplies the bootstrap servers, topic and batch size
   * @param clock clock used for step timing and for the timestamp appended to every line
   */
  public KafkaMeterRegistry(KafkaConfig config, Clock clock) {
    super(config, clock);
    this.config = config;

    String namespace = System.getProperty("NAMESPACE");
    this.key = ObjectUtils.isEmpty(namespace) ? "default" : namespace;

    Properties properties = new Properties();
    properties.setProperty(ProducerConfig.BOOTSTRAP_SERVERS_CONFIG, config.services());
    properties.setProperty(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG,
        "org.apache.kafka.common.serialization.StringSerializer");
    properties.setProperty(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG,
        "org.apache.kafka.common.serialization.StringSerializer");
    this.kafkaProducer = new KafkaProducer<>(properties);

    // Must stay last: the publisher thread may call publish() as soon as it is scheduled.
    start(DEFAULT_THREAD_FACTORY);
  }

  /**
   * Starts the publishing schedule; delegates unchanged to the superclass.
   *
   * @param threadFactory factory for the publisher thread
   */
  @Override
  public void start(ThreadFactory threadFactory) {
    super.start(threadFactory);
  }

  /**
   * Closes the registry and then the Kafka producer it owns.
   *
   * <p>The superclass is closed first so that anything it still publishes while shutting down can
   * use the producer. The producer is closed in a {@code finally} block so it is released even if
   * the superclass close fails.
   */
  @Override
  public void close() {
    try {
      super.close();
    } finally {
      kafkaProducer.close();
    }
  }

  /**
   * Renders every registered meter to line protocol and sends each line to Kafka.
   *
   * <p>Meters are walked in partitions of {@code config.batchSize()}; each rendered line becomes
   * its own record. The result of {@code send} is not inspected here.
   */
  @Override
  protected void publish() {
    for (List<Meter> batch : MeterPartition.partition(this, config.batchSize())) {
      batch.stream()
          .flatMap(meter -> meter.match(
              gauge -> writeGauge(gauge.getId(), gauge.value()),
              counter -> writeCounter(counter.getId(), counter.count()),
              this::writeTimer,
              this::writeSummary,
              this::writeLongTaskTimer,
              timeGauge -> writeGauge(timeGauge.getId(), timeGauge.value(getBaseTimeUnit())),
              functionCounter -> writeCounter(functionCounter.getId(), functionCounter.count()),
              this::writeFunctionTimer,
              this::writeMeter))
          .forEach(line -> kafkaProducer.send(new ProducerRecord<>(config.topic(), key, line)));
    }
  }

  /**
   * @return {@link TimeUnit#MILLISECONDS}; all timer values are reported in this unit
   */
  @Override
  protected TimeUnit getBaseTimeUnit() {
    return TimeUnit.MILLISECONDS;
  }

  /**
   * Renders a gauge as a single {@code value} field.
   *
   * @return one line, or an empty stream when the value is NaN or infinite
   */
  private Stream<String> writeGauge(Meter.Id id, double value) {
    if (Double.isFinite(value)) {
      return Stream.of(influxLineProtocol(id, "gauge", Stream.of(new Field("value", value))));
    }
    return Stream.empty();
  }

  /**
   * Builds one line of the form
   * {@code name[,tagKey=tagValue...],metric_type=TYPE field=value[,field=value...] timestamp}.
   *
   * <p>Tags whose value is blank are omitted. The timestamp is {@code clock.wallTime()}.
   *
   * <p>NOTE(review): names, tag keys and tag values are emitted exactly as the configured naming
   * convention returns them; no escaping of spaces, commas or equals signs is applied here.
   *
   * @param id meter id supplying the name and tags
   * @param metricType value of the synthetic {@code metric_type} tag
   * @param fields fields to emit; consumed by this call
   */
  private String influxLineProtocol(Meter.Id id, String metricType, Stream<Field> fields) {
    String tags = getConventionTags(id).stream()
        .filter(t -> StringUtils.isNotBlank(t.getValue()))
        .map(t -> "," + t.getKey() + "=" + t.getValue())
        .collect(joining(""));

    return getConventionName(id)
        + tags + ",metric_type=" + metricType + " "
        + fields.map(Field::toString).collect(joining(","))
        + " " + clock.wallTime();
  }

  /** A single {@code key=value} field of a line-protocol entry. */
  static class Field {

    final String key;
    final double value;

    /**
     * @param key field name; must not be {@code "time"}
     * @param value field value
     * @throws IllegalArgumentException if {@code key} is {@code "time"}
     */
    Field(String key, double value) {
      // `time` cannot be a field key or tag key
      if (key.equals("time")) {
        throw new IllegalArgumentException("'time' is an invalid field key in InfluxDB");
      }
      this.key = key;
      this.value = value;
    }

    /** Renders the field as {@code key=value}, formatting the value via {@link DoubleFormat}. */
    @Override
    public String toString() {
      return key + "=" + DoubleFormat.decimalOrNan(value);
    }
  }

  /**
   * Renders a counter as a single {@code value} field.
   *
   * @return one line, or an empty stream when the count is NaN or infinite
   */
  private Stream<String> writeCounter(Meter.Id id, double count) {
    if (Double.isFinite(count)) {
      return Stream.of(influxLineProtocol(id, "counter", Stream.of(new Field("value", count))));
    }
    return Stream.empty();
  }

  /**
   * Renders a timer as a {@code histogram} line with {@code sum}, {@code count}, {@code mean} and
   * {@code upper} fields, with times expressed in the base time unit.
   */
  private Stream<String> writeTimer(Timer timer) {
    final Stream<Field> fields = Stream.of(
        new Field("sum", timer.totalTime(getBaseTimeUnit())),
        new Field("count", timer.count()),
        new Field("mean", timer.mean(getBaseTimeUnit())),
        new Field("upper", timer.max(getBaseTimeUnit()))
    );

    return Stream.of(influxLineProtocol(timer.getId(), "histogram", fields));
  }

  /**
   * Renders a distribution summary as a {@code histogram} line with {@code sum}, {@code count},
   * {@code mean} and {@code upper} fields.
   */
  private Stream<String> writeSummary(DistributionSummary summary) {
    final Stream<Field> fields = Stream.of(
        new Field("sum", summary.totalAmount()),
        new Field("count", summary.count()),
        new Field("mean", summary.mean()),
        new Field("upper", summary.max())
    );

    return Stream.of(influxLineProtocol(summary.getId(), "histogram", fields));
  }

  /**
   * Renders a long task timer as a {@code long_task_timer} line with {@code active_tasks} and
   * {@code duration} fields, the duration expressed in the base time unit.
   */
  private Stream<String> writeLongTaskTimer(LongTaskTimer timer) {
    Stream<Field> fields = Stream.of(
        new Field("active_tasks", timer.activeTasks()),
        new Field("duration", timer.duration(getBaseTimeUnit()))
    );
    return Stream.of(influxLineProtocol(timer.getId(), "long_task_timer", fields));
  }

  /**
   * Renders a function timer as a {@code histogram} line with {@code sum} and {@code count}
   * fields, plus {@code mean} when it is finite.
   *
   * @return one line, or an empty stream when the total time is NaN or infinite
   */
  private Stream<String> writeFunctionTimer(FunctionTimer timer) {
    double sum = timer.totalTime(getBaseTimeUnit());
    if (!Double.isFinite(sum)) {
      return Stream.empty();
    }
    Stream.Builder<Field> builder = Stream.builder();
    builder.add(new Field("sum", sum));
    builder.add(new Field("count", timer.count()));
    double mean = timer.mean(getBaseTimeUnit());
    if (Double.isFinite(mean)) {
      builder.add(new Field("mean", mean));
    }
    return Stream.of(influxLineProtocol(timer.getId(), "histogram", builder.build()));
  }

  /**
   * Renders any other meter: one field per finite measurement, keyed by the statistic's tag
   * representation converted from camelCase to snake_case. The {@code metric_type} tag is the
   * lower-cased meter type name.
   *
   * @return one line, or an empty stream when no measurement is finite
   */
  private Stream<String> writeMeter(Meter m) {
    List<Field> fields = new ArrayList<>();
    for (Measurement measurement : m.measure()) {
      double value = measurement.getValue();
      if (!Double.isFinite(value)) {
        continue;
      }
      // Locale.ROOT keeps the key stable regardless of the JVM default locale
      // (a Turkish default locale would otherwise lower-case 'I' to a dotless 'ı').
      String fieldKey = CAMEL_CASE_BOUNDARY
          .matcher(measurement.getStatistic().getTagValueRepresentation())
          .replaceAll("$1_$2")
          .toLowerCase(Locale.ROOT);
      fields.add(new Field(fieldKey, value));
    }
    if (fields.isEmpty()) {
      return Stream.empty();
    }
    Meter.Id id = m.getId();
    return Stream.of(
        influxLineProtocol(id, id.getType().name().toLowerCase(Locale.ROOT), fields.stream()));
  }
}

