package cn.quantgroup.big.stms.sys.service.impl;

import cn.quantgroup.big.stms.common.exception.BizException;
import cn.quantgroup.big.stms.common.result.ResultCode;
import cn.quantgroup.big.stms.common.utils.AESUtils;
import cn.quantgroup.big.stms.common.utils.Constants;
import cn.quantgroup.big.stms.sys.dao.ClientDao;
import cn.quantgroup.big.stms.sys.dto.Oauth2TokenDto;
import cn.quantgroup.big.stms.sys.model.Client;
import cn.quantgroup.big.stms.sys.service.ClientService;
import com.google.gson.Gson;
import lombok.extern.slf4j.Slf4j;
import org.apache.commons.lang3.StringUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.stereotype.Service;

import java.nio.charset.StandardCharsets;
import java.security.MessageDigest;
import java.util.Base64;
import java.util.UUID;
import java.util.concurrent.TimeUnit;

@Service
@Slf4j
public class ClientServiceImpl implements
    ClientService {

  /** Redis TTL of both refresh-token keys, expressed in {@link #REFRESH_TOKEN_EXPIRE_UNIT}. */
  private static final Long REFRESH_TOKEN_EXPIRE = 180L;

  /**
   * Unit for {@link #REFRESH_TOKEN_EXPIRE}. The refresh-token key and the per-client
   * refresh-token index must share one unit so that they expire together.
   */
  private static final TimeUnit REFRESH_TOKEN_EXPIRE_UNIT = TimeUnit.DAYS;

  /** Key segment for the per-client index pointing at the client's current refresh token. */
  private static final String REFRESH_TOKEN_INDEX_SEGMENT = "refreshToken:";

  @Autowired
  private ClientDao clientDao;
  @Autowired
  @Qualifier("stringRedisTemplate")
  private RedisTemplate<String, String> stringRedisTemplate;
  /** Access-token TTL in seconds (applied with TimeUnit.SECONDS when stored in Redis). */
  @Value("${oauth.authorize.token.expire}")
  private Long tokenExpire;
  @Autowired
  private Gson gson;

  /**
   * Looks up a client by its public client id.
   *
   * @param clientId the OAuth2 client id
   * @return the client, or whatever {@link ClientDao#findByClientId} returns when absent
   *     (callers in this class treat that as {@code null})
   */
  @Override
  public Client findByClientId(String clientId) {
    return clientDao.findByClientId(clientId);
  }

  /**
   * Checks whether {@code clientSecret} matches the stored (AES-encrypted) secret of the client.
   *
   * <p>The comparison uses {@link MessageDigest#isEqual} so that the time taken does not depend
   * on how many leading characters match. A missing client, a null submitted secret or a null
   * decryption result all yield {@code false}. NOTE(review): whether AESUtils.decrypt can
   * return null is not visible here; the guard is defensive.
   *
   * @param clientId the OAuth2 client id
   * @param clientSecret the plaintext secret presented by the caller
   * @return {@code true} only if the client exists and the secrets are equal
   */
  @Override
  public Boolean checkSecret(String clientId, String clientSecret) {
    if (clientSecret == null) {
      return false;
    }
    Client client = findByClientId(clientId);
    if (null == client) {
      return false;
    }
    String correctSecret = AESUtils.decrypt(client.getClientSecret());
    if (correctSecret == null) {
      return false;
    }
    return MessageDigest.isEqual(
        correctSecret.getBytes(StandardCharsets.UTF_8),
        clientSecret.getBytes(StandardCharsets.UTF_8));
  }

  /**
   * Exchanges an authorization code for a new access/refresh token pair.
   *
   * <p>The expected code is read from Redis under {@code OAUTH2_CODE_KEY + clientId}. On success
   * the new tokens are persisted and the code key is deleted so the code cannot be reused.
   *
   * @param code the authorization code presented by the client
   * @param clientId the OAuth2 client id
   * @return the issued token pair
   * @throws BizException PARAM_ERROR if {@code code} is blank or the client does not exist;
   *     AUTHORIZE_CODE_MISMATCH if the code does not match the stored one
   */
  @Override
  public Oauth2TokenDto getTokenByCode(String code, String clientId) {
    // Validate input before touching Redis.
    if (StringUtils.isBlank(code)) {
      throw new BizException("授权码不能为空", ResultCode.PARAM_ERROR);
    }
    String codeKey = Constants.OAUTH2_CODE_KEY + clientId;
    String correctCode = stringRedisTemplate.opsForValue().get(codeKey);
    if (!StringUtils.equals(correctCode, code)) {
      throw new BizException(ResultCode.AUTHORIZE_CODE_MISMATCH);
    }

    Client client = findByClientId(clientId);
    if (null == client) {
      // Previously this fell through to client.getScopes() and threw an NPE.
      throw new BizException("客户端不存在", ResultCode.PARAM_ERROR);
    }

    Oauth2TokenDto oauth2TokenDto = issueToken(client, clientId);

    // Authorization codes are single-use: remove it once tokens have been issued.
    // NOTE(review): the get/compare/delete sequence is not atomic, so two concurrent requests
    // with the same code could both succeed.
    stringRedisTemplate.delete(codeKey);
    return oauth2TokenDto;
  }

  /**
   * Exchanges a refresh token for a new access/refresh token pair, revoking the old ones.
   *
   * @param refreshToken the refresh token presented by the client
   * @param clientId the OAuth2 client id; must match the client the refresh token was issued to
   * @return the newly issued token pair
   * @throws BizException PARAM_ERROR if {@code refreshToken} is blank; REFRESH_TOKEN_EXPIRE if
   *     no data is stored for it; REFRESH_TOKEN_MISMATCH if it belongs to a different client
   */
  @Override
  public Oauth2TokenDto getTokenByRefreshToken(String refreshToken, String clientId) {
    if (StringUtils.isBlank(refreshToken)) {
      throw new BizException("refresh_token不能为空", ResultCode.PARAM_ERROR);
    }
    // NOTE(review): access tokens are stored under the same OAUTH2_TOKEN_KEY prefix with the
    // same client JSON value, so an access token would also be accepted here. Separating the
    // namespaces requires coordinated changes wherever access tokens are read.
    String value = stringRedisTemplate.opsForValue().get(Constants.OAUTH2_TOKEN_KEY + refreshToken);
    if (StringUtils.isBlank(value)) {
      throw new BizException(ResultCode.REFRESH_TOKEN_EXPIRE);
    }
    Client client = gson.fromJson(value, Client.class);
    // Reject the request when the token was issued to a DIFFERENT client
    // (the original check was inverted and rejected the legitimate owner).
    if (!StringUtils.equals(client.getClientId(), clientId)) {
      throw new BizException(ResultCode.REFRESH_TOKEN_MISMATCH);
    }

    revokeTokens(client, refreshToken);
    return issueToken(client, clientId);
  }

  /**
   * Deletes the client's current access token (and its index) plus the given refresh token
   * (and the per-client refresh-token index).
   */
  private void revokeTokens(Client client, String refreshToken) {
    String accessIndexKey = Constants.OAUTH2_TOKEN_KEY + client.getId();
    String oldAccessToken = stringRedisTemplate.opsForValue().get(accessIndexKey);
    if (StringUtils.isNotEmpty(oldAccessToken)) {
      stringRedisTemplate.delete(accessIndexKey);
      stringRedisTemplate.delete(Constants.OAUTH2_TOKEN_KEY + oldAccessToken);
    }
    stringRedisTemplate.delete(Constants.OAUTH2_TOKEN_KEY + refreshToken);
    stringRedisTemplate.delete(
        Constants.OAUTH2_TOKEN_KEY + REFRESH_TOKEN_INDEX_SEGMENT + client.getId());
  }

  /** Builds a fresh access/refresh token pair for {@code client}, persists it and returns it. */
  private Oauth2TokenDto issueToken(Client client, String clientId) {
    Oauth2TokenDto oauth2TokenDto = new Oauth2TokenDto();
    oauth2TokenDto.setAccessToken(generateToken());
    oauth2TokenDto.setRefreshToken(generateToken());
    oauth2TokenDto.setClientId(clientId);
    oauth2TokenDto.setExpiresIn(tokenExpire);
    oauth2TokenDto.setScopes(client.getScopes());
    saveLoginToken(client, oauth2TokenDto);
    return oauth2TokenDto;
  }

  /**
   * Generates an opaque token: the Base64 encoding of a random UUID string
   * ({@link UUID#randomUUID} uses a cryptographically strong PRNG).
   */
  private static String generateToken() {
    return Base64.getEncoder()
        .encodeToString(UUID.randomUUID().toString().getBytes(StandardCharsets.UTF_8));
  }

  /**
   * Persists a token pair in Redis under four keys.
   *
   * <ul>
   *   <li>{@code OAUTH2_TOKEN_KEY + client.id} maps to the access token (TTL tokenExpire seconds).
   *   <li>{@code OAUTH2_TOKEN_KEY + accessToken} maps to the client JSON (TTL tokenExpire seconds).
   *   <li>{@code OAUTH2_TOKEN_KEY + refreshToken} maps to the client JSON (refresh TTL).
   *   <li>{@code OAUTH2_TOKEN_KEY + "refreshToken:" + client.id} maps to the refresh token
   *       (refresh TTL).
   * </ul>
   */
  private void saveLoginToken(Client client, Oauth2TokenDto oauth2TokenDto) {
    String clientJson = gson.toJson(client);
    stringRedisTemplate.opsForValue()
        .set(Constants.OAUTH2_TOKEN_KEY + client.getId(), oauth2TokenDto.getAccessToken(),
            tokenExpire, TimeUnit.SECONDS);
    stringRedisTemplate.opsForValue()
        .set(Constants.OAUTH2_TOKEN_KEY + oauth2TokenDto.getAccessToken(), clientJson,
            tokenExpire, TimeUnit.SECONDS);
    stringRedisTemplate.opsForValue()
        .set(Constants.OAUTH2_TOKEN_KEY + oauth2TokenDto.getRefreshToken(), clientJson,
            REFRESH_TOKEN_EXPIRE, REFRESH_TOKEN_EXPIRE_UNIT);
    stringRedisTemplate.opsForValue()
        .set(Constants.OAUTH2_TOKEN_KEY + REFRESH_TOKEN_INDEX_SEGMENT + client.getId(),
            oauth2TokenDto.getRefreshToken(),
            REFRESH_TOKEN_EXPIRE, REFRESH_TOKEN_EXPIRE_UNIT);
  }
}
