package cn.quantgroup.boot;

import java.time.Duration;
import java.time.Instant;
import java.util.Map;
import javax.servlet.http.HttpServletRequest;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.slf4j.MDC;
import org.springframework.web.filter.CommonsRequestLoggingFilter;

/**
 * Request logging filter that copies selected request headers into the SLF4J {@link MDC} and, after
 * the request has been processed, logs the request message together with the elapsed time in
 * milliseconds.
 *
 * <p>Requests whose message contains {@code /health/check} or {@code /actuator/health}, and
 * requests with a {@code multipart/form-data} content type, are not logged after processing.
 */
@Slf4j
public class RequestLoggingFilter extends CommonsRequestLoggingFilter {

  /** Request attribute under which the request's start {@link Instant} is stored. */
  private static final String START_ATTRIBUTE = "metric-start";

  /** Request headers copied into the MDC; each header name is also used as the MDC key. */
  private static final String[] MDC_HEADERS = {"x-b3-version", "QG-Client-Id", "QG-Business-Type"};

  /** Logs only when the inherited filter logger has INFO enabled. */
  @Override
  protected boolean shouldLog(HttpServletRequest request) {
    return logger.isInfoEnabled();
  }

  /**
   * Copies the non-blank values of {@link #MDC_HEADERS} into the MDC and records the request start
   * time, unless the message refers to {@code /health/check}.
   *
   * @param request the current request
   * @param message the message built by the parent filter for this request
   */
  @Override
  protected void beforeRequest(HttpServletRequest request, String message) {
    for (String header : MDC_HEADERS) {
      String value = request.getHeader(header);
      if (StringUtils.isNotBlank(value)) {
        // MDC.put works even when the thread has no MDC context yet; the previous
        // getCopyOfContextMap() approach returned null in that case and threw an NPE.
        MDC.put(header, value);
      }
    }

    if (StringUtils.containsIgnoreCase(message, "/health/check")) {
      return;
    }
    request.setAttribute(START_ATTRIBUTE, Instant.now());
  }

  /**
   * Logs the request message and elapsed time, then removes the header-derived MDC keys so they do
   * not leak into later requests served by the same (pooled) thread.
   *
   * @param request the current request
   * @param message the message built by the parent filter for this request
   */
  @Override
  protected void afterRequest(HttpServletRequest request, String message) {
    try {
      logElapsedTime(request, message);
    } finally {
      // Removed after logging so the log line above still carries these MDC values.
      // NOTE(review): for async requests this runs on the final dispatch thread, so the
      // thread that ran beforeRequest is not cleaned here — confirm whether that matters.
      for (String header : MDC_HEADERS) {
        MDC.remove(header);
      }
    }
  }

  /** Logs message and elapsed milliseconds unless the request is excluded or was never timed. */
  private void logElapsedTime(HttpServletRequest request, String message) {
    String contentType = request.getContentType();
    if (StringUtils.containsAnyIgnoreCase(message, "/health/check", "/actuator/health")
        || StringUtils.containsAnyIgnoreCase(contentType, "multipart/form-data")) {
      return;
    }

    Object startValue = request.getAttribute(START_ATTRIBUTE);
    if (!(startValue instanceof Instant)) {
      // No start time was recorded (beforeRequest skipped or not invoked); a duration
      // cannot be computed, so avoid the NPE Duration.between(null, ...) would throw.
      return;
    }
    long time = Duration.between((Instant) startValue, Instant.now()).toMillis();
    log.info("message:{},time:{}", message, time);
  }
}