package cn.quantgroup.tech.brave.slf4j;

import brave.internal.HexCodec;
import brave.internal.Nullable;
import brave.propagation.CurrentTraceContext;
import brave.propagation.TraceContext;
import org.slf4j.MDC;

import static brave.internal.HexCodec.lowerHexEqualsTraceId;
import static brave.internal.HexCodec.lowerHexEqualsUnsignedLong;

/**
 * Adds {@linkplain MDC} properties TRACE_ID, PARENT_ID and SPAN_ID when a {@link
 * brave.Tracer#currentSpan() span is current}. These can be used in log correlation.
 */
public final class MDCCurrentTraceContext extends CurrentTraceContext {

    private static final String TRACE_ID = "X-B3-TraceId";
    private static final String SPAN_ID = "X-B3-SpanId";
    private static final String PARENT_ID = "X-B3-ParentId";
    private static final String EXPORTABLE = "X-Span-Export";

    public static MDCCurrentTraceContext create() {
        return create(CurrentTraceContext.Default.inheritable());
    }

    public static MDCCurrentTraceContext create(CurrentTraceContext delegate) {
        return new MDCCurrentTraceContext(delegate);
    }

    final CurrentTraceContext delegate;

    MDCCurrentTraceContext(CurrentTraceContext delegate) {
        if (delegate == null) throw new NullPointerException("delegate == null");
        this.delegate = delegate;
    }

    @Override
    public TraceContext get() {
        return delegate.get();
    }

    @Override
    public Scope newScope(@Nullable TraceContext currentSpan) {
        return newScope(currentSpan, MDC.get(TRACE_ID), MDC.get(SPAN_ID), MDC.get(EXPORTABLE));
    }

    @Override
    public Scope maybeScope(@Nullable TraceContext currentSpan) {
        String previousTraceId = MDC.get(TRACE_ID);
        String previousSpanId = MDC.get(SPAN_ID);
        String sampled = MDC.get(EXPORTABLE);

        if (currentSpan == null) {
            if (previousTraceId == null) {
                return Scope.NOOP;
            }
            return newScope(null, previousTraceId, previousSpanId, sampled);
        }
        if (lowerHexEqualsTraceId(previousTraceId, currentSpan)
                && lowerHexEqualsUnsignedLong(previousSpanId, currentSpan.spanId())) {
            return Scope.NOOP;
        }
        return newScope(currentSpan, previousTraceId, previousSpanId, sampled);
    }

    // all input parameters are nullable
    Scope newScope(TraceContext currentSpan, String previousTraceId, String previousSpanId, String sampled) {
        String previousParentId = MDC.get(PARENT_ID);
        if (currentSpan != null) {
            maybeReplaceTraceContext(currentSpan, previousTraceId, previousParentId, previousSpanId, sampled);
        } else {
            MDC.remove(TRACE_ID);
            MDC.remove(PARENT_ID);
            MDC.remove(SPAN_ID);
            MDC.remove(EXPORTABLE);
        }

        Scope scope = delegate.newScope(currentSpan);
        class MDCCurrentTraceContextScope implements Scope {
            @Override
            public void close() {
                scope.close();
                replace(TRACE_ID, previousTraceId);
                replace(PARENT_ID, previousParentId);
                replace(SPAN_ID, previousSpanId);
                //true = 采样了. false = 未采样. null = 决定不了...等会再说
                replace(EXPORTABLE, sampled);
            }
        }
        return new MDCCurrentTraceContextScope();
    }

    void maybeReplaceTraceContext(
            TraceContext currentSpan,
            String previousTraceId,
            @Nullable String previousParentId,
            String previousSpanId,
            @Nullable String sampled
    ) {
        MDC.put(EXPORTABLE, String.valueOf(currentSpan.sampled()));
        boolean sameTraceId = lowerHexEqualsTraceId(previousTraceId, currentSpan);
        if (!sameTraceId) {
            MDC.put(TRACE_ID, currentSpan.traceIdString());
        }

        long parentId = currentSpan.parentIdAsLong();
        if (parentId == 0L) {
            MDC.remove(PARENT_ID);
        } else {
            boolean sameParentId = lowerHexEqualsUnsignedLong(previousParentId, parentId);
            if (!sameParentId) MDC.put(PARENT_ID, HexCodec.toLowerHex(parentId));
        }

        boolean sameSpanId = lowerHexEqualsUnsignedLong(previousSpanId, currentSpan.spanId());
        if (!sameSpanId) MDC.put(SPAN_ID, HexCodec.toLowerHex(currentSpan.spanId()));
    }

    static void replace(String key, @Nullable String value) {
        if (value != null) {
            MDC.put(key, value);
        } else {
            MDC.remove(key);
        }
    }
}
