# coding=utf-8

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

from sklearn.metrics import roc_auc_score

import pymysql
import pymongo

import os
import pickle
import warnings
import datetime
from dateutil.relativedelta import relativedelta
from collections import OrderedDict

# Process-wide configuration applied at import time.
# All Python warnings are silenced for the whole process.
warnings.filterwarnings('ignore')

# Matplotlib defaults used by every figure this module saves:
# - SimHei as the sans-serif font (presumably so the Chinese titles/legends
#   render; requires the font to be installed -- TODO confirm on the host),
# - unicode minus disabled for axis labels,
# - 150 dpi for saved images.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
    'savefig.dpi': 150,
})

class AUCMonitor:
    """Monitor model-score AUC over consecutive application-date windows.

    Workflow (see ``run``): load audit records (from MongoDB or from the CSV
    cached by a previous run), keep disbursed loans, flag each one as overdue
    when ``passdue_day`` exceeds the configured threshold, bucket the records
    into the date windows given by ``date_list`` and, for every model score
    listed in the Excel sheet and every customer segment, compute the AUC,
    draw it and collect it into ``auc_info_df``.
    """

    # NOTE(security): credentials are embedded in source control. They are kept
    # only so existing callers keep working; prefer passing ``mongo_uri`` or
    # setting the AUC_MONITOR_MONGO_URI environment variable, and rotate this
    # password.
    _DEFAULT_MONGO_URI = (
        "mongodb://haoyue.shu:x2egwRHk7WhQ4So1@172.18.3.22:27017/?authSource=rc_mgo_feature_dp")

    # Display names for the numeric ``applied_type`` codes.
    _APP_TYPE_NAMES = {1: '首申', 2: '复申', 3: '复贷'}

    def __init__(self, excel_path='./model_score.xlsx', sheet_name='model',
                 passdue_day=15, save_path='./auc/',
                 min_user_group=500, interval_days=15, min_auc=0.55,
                 date_list=('2019-03-01', '2019-03-15', '2019-03-31', '2019-04-15'),
                 if_read=True, mongo_uri=None):
        """Set up the Mongo handle, model metadata, date windows and output dirs.

        :param excel_path: Excel file with ``field_name`` / ``field_query`` columns.
        :param sheet_name: sheet of ``excel_path`` to read.
        :param passdue_day: a loan counts as overdue when its ``passdue_day``
            is strictly greater than this value.
        :param save_path: root output directory (``image``, ``data`` and
            ``info`` sub-directories are created below it).
        :param min_user_group: minimum segment size per 30 days; it is scaled
            by ``interval_days / 30`` before use.
        :param interval_days: length of one date window in days (only used to
            scale ``min_user_group``).
        :param min_auc: AUC values below this are flagged by ``abnormal_auc``.
        :param date_list: ordered 'YYYY-MM-DD' boundaries; each consecutive
            pair forms one window.
        :param if_read: when true ``run`` queries MongoDB, otherwise it reuses
            the CSV cached under ``save_path``.
        :param mongo_uri: optional connection string; defaults to the
            AUC_MONITOR_MONGO_URI environment variable, then to the built-in
            URI.
        """
        if mongo_uri is None:
            mongo_uri = os.environ.get('AUC_MONITOR_MONGO_URI', self._DEFAULT_MONGO_URI)
        self.mongo_client = pymongo.MongoClient(mongo_uri)
        self.mongo_db = self.mongo_client['rc_mgo_feature_dp']
        self.mongo_table = self.mongo_db['wf_audit_log_with_feature_online']

        # Model metadata maintained in Excel: display name <-> Mongo field.
        self.field_info_df = pd.read_excel(excel_path, sheet_name=sheet_name)
        self.field_name_list = self.field_info_df.field_name.tolist()
        self.field_query_list = self.field_info_df.field_query.tolist()
        self.field_query_name_dict = dict(zip(self.field_query_list, self.field_name_list))

        # Constants.
        self.passdue_day = passdue_day
        self.save_path = save_path
        self.min_user_group = min_user_group * interval_days / 30
        self.min_auc = min_auc

        # Filled by ``run``.
        self.mongo_df = None

        # Date windows, labelled 'start ~ end'.
        self.date_list = date_list
        self.date_range_list = [start + ' ~ ' + end
                                for start, end in zip(self.date_list, self.date_list[1:])]

        # Summary table: one NUM and one AUC column per window.
        auc_cols = ['model_name', 'app_type', 'app_chan', 'group_name']
        for window in self.date_range_list:
            auc_cols.append(window + 'NUM')
            auc_cols.append(window + 'AUC')
        self.auc_info_df = pd.DataFrame(columns=auc_cols)

        self.if_read = if_read

        # makedirs + join: works without a trailing slash on save_path and
        # when the parent directory does not exist yet.
        for sub_dir in ('image', 'data', 'info'):
            os.makedirs(os.path.join(self.save_path, sub_dir), exist_ok=True)

    def query_mongo(self, condition, fields):
        """Run ``find(condition, fields)`` on the audit collection.

        :param condition: dict, Mongo filter.
        :param fields: dict, Mongo projection.
        :return: DataFrame with one row per returned document.
        :raises Exception: whatever the driver raised; a message is printed
            first. (Previously the error was swallowed and ``None`` returned,
            which only deferred the failure to the caller.)
        """
        try:
            return pd.DataFrame(list(self.mongo_table.find(condition, fields)))
        except Exception:
            print('Mongo查询出现错误.')
            raise

    def int2str(self, x):
        """Format an int as a string of at least two digits, e.g. 5 --> '05'.

        :param x: int
        :return: str
        """
        return '%02d' % x

    @staticmethod
    def _to_date_str(value):
        """Return the 'YYYY-MM-DD' part of a timestamp string or datetime-like."""
        if isinstance(value, str):
            return value[:10]
        return value.strftime('%Y-%m-%d')

    @staticmethod
    def _split_group_name(user_group_name):
        """Split 'type-channel' into (type, channel); names without '-' cover all."""
        head, sep, tail = user_group_name.partition('-')
        if not sep:
            return '全类型', '全渠道'
        return head, tail

    def helper_auc(self, user_group_name=None, df=None, info_dict=None, field=None):
        """Collect per-window sample count, overdue ratio and AUC for one segment.

        Results are written to ``info_dict[user_group_name][window]`` as
        ``{'NUM': int, 'overdue_ratio': float, 'AUC': float}``; windows are
        stored in sorted label order and rows with an empty ``date_label``
        are ignored.

        :param user_group_name: str, segment name (key in ``info_dict``).
        :param df: DataFrame with ``date_label``, ``overdue`` and ``field`` columns.
        :param info_dict: dict, updated in place.
        :param field: str, score column used for the AUC.
        :return: None.
        """
        print('正在处理%s客群数据.' % user_group_name)
        info_dict[user_group_name] = OrderedDict()

        # A per-window minimum-sample filter was sketched here in the past but
        # is disabled: every labelled window is recorded.
        labelled = df.loc[df['date_label'] != '']
        for window, rows in labelled.groupby('date_label'):
            overdue = rows['overdue']
            amt = int(overdue.count())
            stats = {'NUM': amt, 'overdue_ratio': float(overdue.mean())}
            print('%s样本量: %d' % (window, amt))

            scored = rows.loc[rows[field].notna()]
            try:
                stats['AUC'] = roc_auc_score(scored['overdue'], scored[field])
            except Exception:
                # Best effort: e.g. a window with a single class has no AUC.
                stats['AUC'] = np.nan
                print('AUC计算发生错误.')
            info_dict[user_group_name][window] = stats
        print('处理完成.')
        print('=' * 40)

    def _collect_group_info(self, df, field):
        """Build ``{segment: {window: stats}}`` for all segments of ``df``.

        Segments: the whole sample, each application type over all channels,
        and every (type, channel) pair with more than ``min_user_group`` rows.
        Segments without any window are dropped.
        """
        info_dict = {}
        self.helper_auc('全样本', df, info_dict, field)
        for type_code, group_name in ((1, '首申-全渠道'), (2, '复申-全渠道'), (3, '复贷-全渠道')):
            self.helper_auc(group_name, df.loc[df['applied_type'] == type_code], info_dict, field)

        # Size of each (type, channel) pair, largest first; small ones dropped.
        sizes = df.groupby(['applied_type', 'applied_channel'])[field].count()
        sizes = sizes.sort_values(ascending=False).reset_index()
        sizes = sizes.loc[sizes[field] > self.min_user_group]

        kept_pairs = set(zip(sizes['applied_type'], sizes['applied_channel']))
        user_group_dict = {}
        for app_type in sizes['applied_type'].unique():
            for app_chan in sizes['applied_channel'].unique():
                if (app_type, app_chan) in kept_pairs:
                    # Unknown type codes fall back to the raw code instead of
                    # raising KeyError.
                    type_name = self._APP_TYPE_NAMES.get(app_type, app_type)
                    user_group_dict['%s-%s' % (type_name, app_chan)] = (app_type, app_chan)

        for group_name, (app_type, app_chan) in user_group_dict.items():
            self.helper_auc(group_name,
                            df.loc[(df['applied_type'] == app_type) &
                                   (df['applied_channel'] == app_chan)],
                            info_dict, field)

        return {name: info for name, info in info_dict.items() if info}

    def _draw_group(self, field, user_group_name, group_info):
        """Plot one segment's AUC per window, save it under ``image/`` and show it."""
        title = self.field_query_name_dict[field] + '-' + user_group_name
        print(title)
        plt.figure(figsize=(16, 8))
        auc_list = []
        label_parts = []
        for window in self.date_range_list:
            if window in group_info:
                stats = group_info[window]
                auc_list.append(stats['AUC'])
                label_parts.append('%s AUC: %.3f 样本量: %d\n' % (window, stats['AUC'], stats['NUM']))
            else:
                auc_list.append(np.nan)
                label_parts.append('%s AUC: %s 样本量: %s\n' % (window, 'NaN', 'NaN'))
        plt.plot(range(len(self.date_range_list)),
                 auc_list, '*--',
                 label=''.join(label_parts))
        plt.legend(loc='upper right', fontsize=18)
        plt.title(title, fontdict={'fontsize': 40})
        plt.subplots_adjust(left=0.03, right=0.99, top=0.91, bottom=0.03)
        plt.savefig(os.path.join(self.save_path, 'image', title))
        plt.show()
        # Release the figure; otherwise one figure per segment stays alive.
        plt.close()

    def _record_group_info(self, field, info_dict):
        """Append one summary row per segment to ``auc_info_df``."""
        rows = []
        for user_group_name, group_info in info_dict.items():
            app_type, app_chan = self._split_group_name(user_group_name)
            row = {'model_name': self.field_query_name_dict[field],
                   'app_type': app_type,
                   'app_chan': app_chan,
                   'group_name': user_group_name}
            for window, stats in group_info.items():
                row[window + 'NUM'] = int(stats['NUM'])
                row[window + 'AUC'] = round(stats['AUC'], 3)
            rows.append(row)
        if rows:
            # DataFrame.append was removed in pandas 2.0; concat is equivalent.
            self.auc_info_df = pd.concat([self.auc_info_df, pd.DataFrame(rows)],
                                         ignore_index=True, sort=False)

    def plot_auc(self, field):
        """Compute, draw and record the AUC of one score for every segment.

        Reads ``self.mongo_df`` (as prepared by ``run``), keeps rows that have
        an overdue label, a positive score, a ``passdue_day`` and an
        application date not later than the last boundary, then processes
        each segment.

        :param field: str, score column in ``self.mongo_df``.
        :return: None.
        """
        df_copy = self.mongo_df[
            [field, 'date_label', 'applied_type', 'applied_channel', 'overdue', 'passdue_day', 'applied_at']].copy()
        # Only disbursed loans with observed performance and a score.
        df_copy = df_copy[(df_copy['overdue'].notna()) & (df_copy[field].notna())]
        if df_copy.shape[0] == 0:
            print('仍在空跑.')
            return None

        # Per-value conversion works for string and datetime columns alike,
        # independent of how pandas reports the column dtype.
        applied_date = df_copy['applied_at'].apply(self._to_date_str)
        df_copy = df_copy.loc[(applied_date <= self.date_list[-1]) &
                              (df_copy[field] > 0) &
                              (df_copy['passdue_day'].notna())]

        info_dict = self._collect_group_info(df_copy, field)

        print('开始画图.')
        print('=' * 40)
        for user_group_name, group_info in info_dict.items():
            self._draw_group(field, user_group_name, group_info)

        self._record_group_info(field, info_dict)

    def abnormal_auc(self):
        """Add ``is_abnormal`` to ``auc_info_df``.

        A row is abnormal when any column whose name contains 'AUC' holds a
        non-missing value below ``self.min_auc``. Works on an empty table too.
        """
        flags = np.zeros(len(self.auc_info_df), dtype=bool)
        for col in self.auc_info_df.columns:
            if 'AUC' not in col:
                continue
            values = np.asarray(pd.to_numeric(self.auc_info_df[col], errors='coerce'), dtype=float)
            # NaN compares False, so missing AUCs are never flagged.
            flags |= values < self.min_auc
        self.auc_info_df['is_abnormal'] = flags

    def run(self):
        """Execute the whole pipeline and write ``info/auc_info.csv``.

        With ``if_read`` true the data is fetched from MongoDB and cached as
        ``data/mongo_data.csv``; otherwise that cached CSV is reused. In both
        cases the CSV is the source that gets parsed and normalised.
        """
        def func_0(data):
            # Loan type code -> applied_type (code + 1); NaN when not numeric.
            try:
                return int(int(data) + 1)
            except (TypeError, ValueError, OverflowError):
                return np.nan

        def overdue(data):
            # 1.0 when overdue longer than the threshold, 0.0 otherwise.
            if pd.isnull(data):
                return np.nan
            return float(data > self.passdue_day)

        def clean_data(data):
            # Scores must be numeric and within (0, 999999]; anything else -> NaN.
            if pd.isnull(data):
                return np.nan
            try:
                if data <= 0 or data > 999999:
                    return np.nan
                return float(data)
            except (TypeError, ValueError):
                return np.nan

        cache_csv = os.path.join(self.save_path, 'data', 'mongo_data.csv')

        if self.if_read:
            condition = {'wf_created_at': {'$gte': '%s 00:00:00' % self.date_list[0],
                                           '$lte': '%s 00:00:00' % self.date_list[-1]},
                         'passdue_day': {'$ne': None}}
            fields = {'wf_biz_no': 1, 'wf_created_at': 1, 'wf_loan_type': 1,
                      'passdue_day': 1, 'wf_biz_channel': 1, 'applied_channel_cn': 1,
                      'lam_transaction_status': 1,
                      'repayment_status': 1
                      }
            for f in self.field_query_list:  # model scores listed in the Excel sheet
                fields[f] = 1
            self.mongo_df = self.query_mongo(condition, fields)
            print('MongoDB数据获取成功.')
            self.mongo_df.to_csv(cache_csv, index=False)

        # Always parse from the CSV so both modes see identical dtypes.
        self.mongo_df = pd.read_csv(cache_csv)

        # The cache holds the raw Mongo column names, so normalisation must run
        # for both modes (it used to run only after a fresh query, which made
        # if_read=False fail on the missing 'applied_type' column).
        if 'wf_loan_type' in self.mongo_df.columns:
            self.mongo_df['wf_loan_type'] = self.mongo_df['wf_loan_type'].apply(func_0)
        self.mongo_df = self.mongo_df.rename(columns={'wf_loan_type': 'applied_type',
                                                      'wf_created_at': 'applied_at',
                                                      'wf_biz_channel': 'applied_from',
                                                      'applied_channel_cn': 'applied_channel'})
        self.mongo_df = self.mongo_df.loc[self.mongo_df['applied_type'].notna()].copy()

        # Keep disbursed loans only and derive the overdue label.
        self.mongo_df['is_loan'] = ((self.mongo_df.lam_transaction_status.isin([2, 5])) &
                                    (self.mongo_df.repayment_status.isin([0, 1, 2, 3, 5, 6])))
        self.mongo_df = self.mongo_df.loc[self.mongo_df['is_loan']].copy()
        self.mongo_df['overdue'] = self.mongo_df['passdue_day'].apply(overdue)
        del self.mongo_df['lam_transaction_status']
        del self.mongo_df['repayment_status']

        # Application time as a uniform 'YYYY-MM-DD' string.
        self.mongo_df['applied_at'] = self.mongo_df['applied_at'].apply(self._to_date_str)

        # Clean the score columns that were actually returned.
        na_field_list = []
        for field in self.field_query_list:
            if field in self.mongo_df.columns.tolist():
                print('正在清洗%s' % self.field_query_name_dict[field])
                self.mongo_df[field] = self.mongo_df[field].apply(clean_data)
            else:
                na_field_list.append(field)
        # Forget the scores that were not extracted.
        print('不包含以下字段:')
        for field in na_field_list:
            print(self.field_query_name_dict[field])
            idx = self.field_query_list.index(field)
            self.field_query_list.pop(idx)
            self.field_name_list.pop(idx)
            del self.field_query_name_dict[field]

        # Bucket by application date; each window is [start, end).
        self.mongo_df['date_label'] = ''
        for start, end in zip(self.date_list, self.date_list[1:]):
            in_window = (self.mongo_df['applied_at'] >= start) & (self.mongo_df['applied_at'] < end)
            self.mongo_df.loc[in_window, 'date_label'] = start + ' ~ ' + end

        print('开始画图-AUC.')
        for field in self.field_query_list:
            self.plot_auc(field)

        # Flag abnormal rows, then persist the summary.
        self.abnormal_auc()
        self.auc_info_df.to_csv(os.path.join(self.save_path, 'info', 'auc_info.csv'), index=False)
        print('统计信息保存成功.')
