# -*- coding: utf-8 -*-
"""
Created on Mon Dec  3 18:18:12 2018

@author: Jason Wang
"""
import time
import os
import pymysql
import pandas as pd
import numpy as np
import openpyxl
import decimal
import matplotlib.pyplot as plt
import os
from matplotlib.font_manager import FontProperties
from matplotlib.lines import Line2D
import datetime
import sklearn.metrics
from django.db import transaction, DatabaseError


# Benchmark (baseline) sample query.
# Placeholders (@modelVar, @channelID, @appliedType, @passdueday) are filled in
# with plain str.replace() by the main loop before the query is executed.
# The query returns the model variable, the `transacted` flag and an `overdue`
# flag (1 when passdue_day > @passdueday) for rows whose applied_at lies within
# 30 days after the earliest applied_at date of a transacted row that has a
# non-null model variable in the given channel(s). Rows are further restricted
# to the given channel(s) / applied type(s) and to non-null, positive values of
# the model variable.
sql_bins = '''
SELECT @modelVar,transacted,IF(passdue_day>@passdueday,1,0) as overdue FROM risk_analysis
WHERE applied_at BETWEEN 
(SELECT date_format(applied_at,'%Y-%m-%d')
FROM risk_analysis
WHERE !ISNULL(@modelVar) AND transacted=1 and applied_from IN (@channelID)
ORDER BY applied_at asc
LIMIT 1) AND DATE_ADD((SELECT date_format(applied_at,'%Y-%m-%d')
FROM risk_analysis
WHERE !ISNULL(@modelVar) AND transacted=1 and applied_from IN (@channelID)
ORDER BY applied_at asc
LIMIT 1),INTERVAL 30 DAY)
AND applied_from IN (@channelID)
AND applied_type IN (@appliedType)
AND !ISNULL(@modelVar)
AND @modelVar > 0
'''


# Observation sample query used for the PSI chart.
# Returns the application month ('%Y-%m') and the model variable for rows whose
# application month lies between three months ago and one month ago (both
# inclusive, relative to NOW()), restricted to the given channel(s) and applied
# type(s) and to non-null values of the model variable.
# Placeholders: @modelVar, @channelID, @appliedType.
sql_observation = '''
SELECT date_format(applied_at,'%Y-%m') as applied_at,@modelVar 
FROM risk_analysis
WHERE DATE_FORMAT(applied_at,'%Y-%m')
BETWEEN DATE_FORMAT(DATE_ADD(NOW(),INTERVAL -3 MONTH),'%Y-%m')
AND DATE_FORMAT(DATE_ADD(NOW(),INTERVAL -1 MONTH),'%Y-%m')
AND applied_from IN (@channelID)
AND applied_type IN (@appliedType)
AND !ISNULL(@modelVar)
'''

######## calculate with T-N mon ###########
# Validation sample query used for the lift chart.
# It is the UNION ALL of three rolling 30-day windows on `deadline`, each
# labelled through the `applied_at` output column:
#   'T-1': deadline in [NOW()-45d, NOW()-15d)
#   'T-2': deadline in [NOW()-75d, NOW()-45d)
#   'T-3': deadline in [NOW()-105d, NOW()-75d)
# Only transacted rows of the given channel(s) / applied type(s) are returned,
# together with order_no, the model variable and an `overdue` flag
# (1 when passdue_day > @passdue_day).
# Placeholders: @modelVar, @channelID, @appliedType, @passdue_day.
# Unlike the other templates, null model variables are not filtered out here.

sql_passdueday = '''
(SELECT order_no,'T-1' as applied_at,@modelVar,IF(passdue_day > @passdue_day,1,0) as overdue
FROM risk_analysis
WHERE DATE_FORMAT(deadline,'%Y-%m-%d') >= DATE_FORMAT(DATE_ADD(NOW(),INTERVAL -45 DAY),'%Y-%m-%d') and DATE_FORMAT(deadline,'%Y-%m-%d') < DATE_FORMAT(DATE_ADD(NOW(),INTERVAL -15 DAY),'%Y-%m-%d')
AND applied_from IN (@channelID)
AND applied_type IN (@appliedType)
AND transacted = 1)
UNION ALL
(SELECT order_no,'T-2' as applied_at,@modelVar,IF(passdue_day > @passdue_day,1,0) as overdue
FROM risk_analysis
WHERE DATE_FORMAT(deadline,'%Y-%m-%d') >= DATE_FORMAT(DATE_ADD(NOW(),INTERVAL -75 DAY),'%Y-%m-%d') and DATE_FORMAT(deadline,'%Y-%m-%d') < DATE_FORMAT(DATE_ADD(NOW(),INTERVAL -45 DAY),'%Y-%m-%d')
AND applied_from IN (@channelID)
AND applied_type IN (@appliedType)
AND transacted = 1)
UNION ALL
(SELECT order_no,'T-3' as applied_at,@modelVar,IF(passdue_day > @passdue_day,1,0) as overdue
FROM risk_analysis
WHERE DATE_FORMAT(deadline,'%Y-%m-%d') >= DATE_FORMAT(DATE_ADD(NOW(),INTERVAL -105 DAY),'%Y-%m-%d') and DATE_FORMAT(deadline,'%Y-%m-%d') < DATE_FORMAT(DATE_ADD(NOW(),INTERVAL -75 DAY),'%Y-%m-%d')
AND applied_from IN (@channelID)
AND applied_type IN (@appliedType)
AND transacted = 1) 
'''

############ calculate with natural mon #############
# Disabled alternative to sql_passdueday that buckets by the calendar month of
# loan_start_date. The text below is a bare string expression (it is never
# assigned to a name), so it has no effect at runtime.
# NOTE(review): its lower bound is formatted with '%Y-%m-%s' while the compared
# value and the upper bound use '%Y-%m' - this looks like a typo; confirm before
# re-enabling the query.
"""
sql_passdueday = '''
SELECT date_format(loan_start_date,'%Y-%m') as applied_at,@modelVar,IF(passdue_day > @passdueday,1,0) as overdue
FROM risk_analysis
WHERE DATE_FORMAT(loan_start_date,'%Y-%m')
BETWEEN DATE_FORMAT(DATE_ADD(NOW(),INTERVAL -4 MONTH),'%Y-%m-%s')
AND DATE_FORMAT(DATE_ADD(NOW(),INTERVAL -2 MONTH),'%Y-%m')
AND applied_from IN (@channelID)
AND applied_type IN (@appliedType)
AND !ISNULL(@modelVar)
AND transacted = 1
'''
"""

# NOTE(review): `passdue_day` is not referenced anywhere in the visible code;
# the queries and plots use `passdueday` (defined further down) instead.
passdue_day = 15

#AND applied_from IN (@channelID)
##################################### db config ###############################
# Connection settings for the risk_analysis database.
# Every value can be overridden through a RISK_DB_* environment variable so
# that credentials do not have to live in source control; the literals are kept
# only as backward-compatible defaults.
# TODO(review): this password is committed in clear text - rotate it and drop
# the hard-coded default once the environment variables are deployed.
risk_analysis_config = {
    'user': os.environ.get('RISK_DB_USER', 'fengkong_read_only'),
    'password': os.environ.get('RISK_DB_PASSWORD', 'mT2HFUgI'),
    'host': os.environ.get('RISK_DB_HOST', '172.20.6.9'),
    # The driver expects an integer port, environment values are strings.
    'port': int(os.environ.get('RISK_DB_PORT', 9030)),
    'database': os.environ.get('RISK_DB_NAME', 'risk_analysis'),
    'encoding': os.environ.get('RISK_DB_ENCODING', 'utf8'),
}


#################################################################################
# Output directory prefix for the generated charts; the plot functions build
# the file name as path + title + ".png", so it has to end with a separator.
path = "../plot/PSI_VAL/"
# Excel workbook that lists which score columns to analyse.
mapping_path = "./query_score.xlsx"

# NOTE: the sheet name is spelled 'score_risk_anlysis' (sic); it must match the
# workbook exactly.
mapping = pd.read_excel(mapping_path,sheet_name='score_risk_anlysis')

# Parallel lists, one entry per row of the mapping sheet:
#   modelType       - `description` column, used in chart titles
#   modelList       - `score` column, substituted for @modelVar in the queries
#   appliedTypeList - `appliedType` column; the main loop splits it on ';'
#   channelIDList   - `channel` column (not referenced in the visible code)
modelType = mapping.description.tolist()
modelList = mapping.score.tolist()
appliedTypeList = mapping.appliedType.tolist()
channelIDList = mapping.channel.tolist()

#modelBound_dict = mapping[['feature','boundary']].set_index('feature').boundary.to_dict()

# The frame is only needed to build the lists above.
del mapping

# Display label for each applied-type selector used in chart titles
# (presumably: overall / first application / re-application / repeat loan -
# confirm the translations with the data owner).
appliedType_type = {'1,2,3':'总体','1':'首申','2':'复申','3':'复贷'}

passdueday = 15 #more than N days (fstOverdue N+)

def connect2DB(db_config):
    """Open a new PyMySQL connection described by ``db_config``.

    Params :
        db_config - dict with the keys 'host', 'port', 'user', 'password',
                    'database' and 'encoding' (used as the connection charset)

    Returns :
        the connection object returned by ``pymysql.connect``; the caller is
        responsible for closing it.
    """
    # `password` / `database` are the documented keyword names; the former
    # `passwd` / `db` spellings are deprecated aliases kept by PyMySQL for
    # MySQLdb compatibility.
    return pymysql.connect(
        host=db_config['host'],
        port=db_config['port'],
        user=db_config['user'],
        password=db_config['password'],
        database=db_config['database'],
        charset=db_config['encoding'])


def query_sql(sql,db_config=risk_analysis_config):
    """Run ``sql`` and return the result as a DataFrame.

    Params :
        sql - complete SQL statement (placeholders already substituted)
        db_config - connection settings, see ``connect2DB``

    Returns :
        a pandas DataFrame on success, or the integer 0 when connecting or
        querying fails (kept for backward compatibility with existing callers).
        The failure reason is printed instead of being silently discarded.
    """
    conn = None
    try:
        conn = connect2DB(db_config)
        return pd.read_sql(sql, conn)
    except Exception as e:
        print('query_sql exception', e)
        return 0
    finally:
        # Always release the connection, also when read_sql raised.
        if conn is not None:
            try:
                conn.close()
            except Exception as e:
                print('query_sql close exception', e)

################################### plot PSI ##################################
 #+'\nmissing:'+str(missing[int(i/2)])+'%'   
def plotPSI(title,y_list,dateList,psi,missing,rows,cols,table_value,save_path):
    """Draw one PSI line per period with a count table below and save it.

    Params :
        title - figure title, also used as the file name (title + ".png")
        y_list - one sequence of per-bin percentages for each period
        dateList - period labels; only the first 7 characters are shown
        psi - PSI value per period, indexed by position
        missing - missing rate per period, indexed by position
        rows - row labels of the table
        cols - column labels of the table (the bin labels)
        table_value - cell texts of the table, one sequence per row
        save_path - prefix concatenated directly in front of the file name

    Returns :
        1 once the figure has been saved and shown.
    """
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False
    plt.rcParams['savefig.dpi'] = 226 # dpi of the saved image
    plt.rcParams['figure.dpi'] = 100 # dpi of the on-screen figure
    fig,axs = plt.subplots(1,1,figsize=(16,9),linewidth=0.1)
    try:
        for y_index, y in enumerate(y_list):
            label = (dateList[y_index][0:7] + ' PSI:' + str(psi[y_index])
                     + '\n缺失率:' + str(missing[y_index]) + '%')
            axs.plot(range(len(y)), y, marker='o', label=label)

        # Guard against a single-column table, which used to divide by zero.
        col_width = 0.91 / max(len(cols) - 1, 1)
        the_table = plt.table(cellText=table_value,
                              rowLabels=rows,
                              colLabels=cols,
                              colWidths=[col_width] * len(cols),
                              loc='bottom')

        the_table.auto_set_font_size(False)
        the_table.set_fontsize(8)
        # Leave room below the axes for the table.
        fig.subplots_adjust(bottom=0.2)
        plt.grid()
        plt.ylabel('各分段样本占比'+' (%)')
        plt.legend()
        plt.xticks([])
        fig.suptitle(title)
        plt.savefig(save_path + title + ".png")
        plt.show()
    finally:
        # This function is called in a loop; without closing, every figure
        # stays registered with pyplot and memory grows with each chart.
        plt.close(fig)
    return 1
  
########################### validation liftchart###############################
def plotLiftChart(title,y_list,dateList,aucri,auc,rows,cols,table_value,save_path):
    """Draw one overdue-rate line per period with a count table and save it.

    Params :
        title - figure title, also used as the file name (title + ".png")
        y_list - one sequence of per-bin overdue percentages for each period
        dateList - period labels; only the first 7 characters are shown
        aucri - AUC relative to the benchmark AUC, one value per period
        auc - AUC (or preformatted AUC text) per period
        rows - row labels of the table
        cols - column labels of the table (the bin labels)
        table_value - cell texts of the table, one sequence per row
        save_path - prefix concatenated directly in front of the file name

    The y-axis label reads the module-level ``passdueday`` threshold.

    Returns :
        1 once the figure has been saved and shown.
    """
    plt.rcParams['font.sans-serif'] = ['SimHei']
    plt.rcParams['axes.unicode_minus'] = False
    plt.rcParams['savefig.dpi'] = 226 # dpi of the saved image
    plt.rcParams['figure.dpi'] = 100 # dpi of the on-screen figure
    fig,axs = plt.subplots(1,1,figsize=(16,9),linewidth=0.1)
    try:
        for y_index, y in enumerate(y_list):
            label = (dateList[y_index][0:7] + ' (AUCRI:' + str(aucri[y_index])
                     + ') AUC: ' + str(auc[y_index]))
            axs.plot(range(len(y)), y, marker='o', label=label)

        # Guard against a single-column table, which used to divide by zero.
        col_width = 0.91 / max(len(cols) - 1, 1)
        the_table = plt.table(cellText=table_value,
                              rowLabels = rows,
                              colLabels = cols,
                              colWidths = [col_width] * len(cols),
                              loc = 'bottom')
        the_table.auto_set_font_size(False)
        the_table.set_fontsize(8)
        # Leave room below the axes for the table.
        fig.subplots_adjust(bottom = 0.2)
        plt.legend()
        plt.grid()
        plt.ylabel('贷后首逾'+str(passdueday)+'+ (%)')
        plt.xticks([])
        fig.suptitle(title)
        plt.savefig(save_path + title + ".png")
        plt.show()
    finally:
        # This function is called in a loop; without closing, every figure
        # stays registered with pyplot and memory grows with each chart.
        plt.close(fig)
    return 1

###############################################################################
#def dataManipul(df,keyword,interval):
#    
#    # df count of all records
#    
#    df_count = df[['applied_at',keyword]].fillna(0).groupby('applied_at').count()[keyword]
#    df_zeros = pd.Series(np.zeros(df_count.shape),index = df_count.index)
#    
#    df_missing = df[['applied_at',keyword]].groupby('applied_at').count()    
#    df_missing = pd.concat([df_zeros,df_missing], axis = 1, sort = True).fillna(0)[keyword]
#   
##    df_shape = pd.DataFrame(np.zeros(df_count.shape))
##    
##    df_missing = df[df[keyword].isnull()].fillna(0).groupby('applied_at')[keyword].count()
##    df_missing = df_shape + df_missing
#    missing_rate = df_missing / df_count * 100
#    
#    df_zero = df[df[keyword] == 0].groupby('applied_at')[keyword].count()
#    df_zero = pd.concat([df_zeros,df_zero], axis = 1, sort = True).fillna(0)[keyword]
#    zero_rate = df_zero / df_count * 100
#    
#    df_noneNA = df.dropna(axis = 0)
#    df_sum = df_noneNA.groupby('applied_at').agg(['mean','std','count'])
#    
#    cols = df_count.index
#    return zero_rate,missing_rate,cols,df_sum 

def dataManipul(df,keyword,interval):
    """Summarise the distribution of ``keyword`` per period and per bin.

    Params :
        df - DataFrame with an 'applied_at' column (the period label) and the
             ``keyword`` column; it is not modified
        keyword - name of the score column
        interval - bin edges passed to ``pd.cut``

    Negative values of ``keyword`` are treated as missing. (The previous
    implementation assigned the cleaned values to a temporary copy, so
    negatives were in fact never nulled.)

    Returns a tuple :
        zero_rate - % of values equal to 0 per (period, bin), rounded to 0.1
        missing_rate - % of missing values per period, rounded to 0.1; a period
                       without any valid value reports 100
        rows - list of periods that have at least one valid value
        cols - bin labels as strings
        y - % of the period's valid values falling into each bin, rounded to 0.1
        count - number of valid values per (period, bin)

    NOTE(review): ``pd.cut`` bins are right-closed, so a value equal to the
    lowest edge falls outside every bin; when the lowest edge is 0 (as the main
    loop sets it) zero_rate is therefore presumably always 0 - confirm whether
    that is intended.
    """
    # Work on a copy so the caller's frame is left untouched.
    data = df.copy()
    # Set negative values to NaN so that they are counted as missing below.
    data[keyword] = data[keyword].where(~(data[keyword] < 0))
    valid = data.dropna(axis=0).copy()

    # Missing rate per period: all rows versus rows that survived dropna.
    total_per_period = data[['applied_at', keyword]].fillna(0).groupby('applied_at')[keyword].count()
    valid_per_period = valid.groupby('applied_at')[keyword].count()
    missing_per_period = total_per_period - valid_per_period.reindex(total_per_period.index, fill_value=0)
    missing_rate = missing_per_period / total_per_period.replace(0, 1) * 100

    valid['bins'] = pd.cut(valid[keyword], interval, precision=6)
    cols = valid['bins'].value_counts().sort_index().index.astype('str')

    # observed=False keeps empty bins in the result so that every period has
    # one entry per bin, which the plotting code relies on.
    group_keys = [valid['applied_at'], valid['bins']]
    bin_counts = valid[keyword].groupby(group_keys, observed=False).count()
    zero_counts = (valid[keyword] == 0).groupby(group_keys, observed=False).sum()
    zero_rate = zero_counts / bin_counts.replace(0, 1) * 100

    # Share of each bin within its period, broadcast over the period level.
    y = bin_counts.div(valid_per_period, level='applied_at') * 100
    rows = y.index.levels[0].tolist()

    return zero_rate.round(1),missing_rate.round(1),rows,cols,y.round(1),bin_counts

def psi_bins(df,keyword,interval):
    """Return the benchmark distribution of ``keyword`` over the given bins.

    Params :
        df - DataFrame holding the ``keyword`` column; a 'bins' column with the
             assigned interval is added to it in place
        keyword - name of the score column
        interval - bin edges passed to ``pd.cut``

    Returns :
        Series indexed by bin with the percentage of binned rows in each bin.
        Empty bins are kept (value 0) so that the result always has one entry
        per bin.
    """
    df.loc[:,'bins'] = pd.cut(df[keyword],interval,precision=6)
    # observed=False pins the "keep empty categories" behaviour instead of
    # relying on the pandas default.
    bin_counts = df.groupby('bins', observed=False)[keyword].count()
    return bin_counts / bin_counts.values.sum() * 100

# draw liftchart
def liftchart(df,keyword,interval):
    """Compute sample counts and overdue rates per period and score bin.

    Params :
        df - DataFrame with the columns 'applied_at' (period label), 'overdue'
             (1 marks an overdue row) and ``keyword``; a 'bins' column with
             the assigned interval is added to it in place
        keyword - name of the score column
        interval - bin edges passed to ``pd.cut``; when ``df`` contains
                   negative scores an extra lowest edge of -10000000 is added
                   so that they get a bin of their own

    Returns a tuple :
        count - number of rows per (period, bin)
        overdue_count - number of overdue rows per (period, bin), as integers
        y - overdue rate in % per (period, bin), rounded to 3 decimals
        rows - list of periods
        cols - bin labels as strings
    """
    if (df[keyword] < 0).any():
        # Extra catch-all edge so that negative scores are not dropped.
        bins_interval = sorted(list(interval) + [-10000000])
    else:
        bins_interval = interval
    # The caller relies on this column being added to its own frame.
    df.loc[:,'bins'] = pd.cut(df[keyword],bins_interval,precision=6)

    # observed=False keeps empty bins so that every period has one entry per
    # bin, which the plotting code relies on.
    group_keys = [df['applied_at'], df['bins']]
    count = df['overdue'].groupby(group_keys, observed=False).count()
    overdue_count = df['overdue'].eq(1).groupby(group_keys, observed=False).sum()

    # replace(0, 1) avoids 0/0 for empty bins; their rate is reported as 0.
    y = overdue_count / count.replace(0,1) * 100
    rows = y.index.levels[0].tolist()
    cols = df['bins'].value_counts().sort_index().index.astype('str').tolist()

    return count,overdue_count,y.round(3),rows,cols


# Query listing the distinct channels (applied_from, applied_channel) that had
# transacted loans with real_loan_amount > 20000 started during the previous
# calendar month, excluding the channel ids 159481, 159486 and 159528
# (per the original author's note these are the recalling channels).
sql_channel = '''
SELECT DISTINCT(applied_from),applied_channel FROM risk_analysis
WHERE transacted = 1
AND real_loan_amount > 20000
AND loan_start_date >= DATE_FORMAT(DATE_ADD(NOW(),INTERVAL -1 MONTH),'%Y-%m-01') 
AND loan_start_date < DATE_FORMAT(NOW(),'%Y-%m-01')
and applied_from not in (159481,159486,159528)
'''



# Channel groups to analyse. Key: comma-separated applied_from ids that are
# substituted into "applied_from IN (@channelID)"; value: display name used in
# the chart titles.
channel = {'1,214,217,198':'内部','159507':'浅橙','159537':'360金融','333':'融360','159384,159483':'平安','159561':'51公积金API'}

# Channels that were active last month according to sql_channel.
channelId = query_sql(sql_channel).applied_from

# Ids that already belong to one of the named groups above. Parsed as integers
# instead of eval()-ing a string assembled from the keys.
known_ids = {int(part) for key in channel for part in key.split(',')}
active_ids = channelId.tolist()
other_ids = [cid for cid in active_ids if cid not in known_ids]

# Two extra groups: every active channel not named above, and all active
# channels together. The ids are joined with ', ' which yields exactly the text
# str(list) produced before (without the surrounding brackets).
channel[', '.join(map(repr, other_ids))] = '其他渠道'
channel[', '.join(map(repr, active_ids))] = '全部渠道'



# Main driver: for every combination of model score column, applied type and
# channel group, build a PSI chart and a lift (validation) chart.
# The two steps run in separate try blocks; any failure is only printed and the
# loop moves on to the next combination.
# NOTE(review): modelList.index(modelVar) returns the first match, so duplicate
# score names in the mapping sheet would all use the first row's settings.
for modelVar in modelList:
    print('model: ',modelVar)
    # The appliedType cell of the mapping row may hold several selectors
    # separated by ';'.
    for appliedType in str(appliedTypeList[modelList.index(modelVar)]).split(';'):
#        print('appliedType',appliedType)
#        print('appliedTypeList[model_index]',appliedTypeList[modelList.index(modelVar)])
        for channelID in channel.keys():
            # ---------------------------- PSI chart ----------------------------
            try:
                print('channelID:',channelID)
                                
                # Benchmark sample. query_sql returns 0 on failure, in which
                # case .dropna raises and the handler below reports it.
                df_bins = query_sql(sql_bins.replace('@modelVar',modelVar).replace('@appliedType',appliedType).replace('@channelID',channelID).replace('@passdueday',str(passdueday))).dropna(axis=0)
            
                # Observation sample of the recent months.
                df_observation = query_sql(sql_observation.replace('@modelVar',modelVar).replace('@appliedType',appliedType).replace('@channelID',channelID)) 
                
                # Negative scores are treated as missing values.
                df_observation.loc[:,modelVar] = df_observation.loc[:,modelVar].map(lambda x : np.nan if x < 0 else x)
            #df_bins = df_bins.apply(lambda x :np.nan if x < 0 else x)
                # Decile edges of the benchmark scores; the lowest edge is
                # forced to 0.
                Nothing,interval = pd.qcut(df_bins.loc[:,modelVar],10,retbins=True,precision=6,duplicates='drop')
                interval[0] = 0
                del Nothing
                # Benchmark share (%) per bin and the per-month observation
                # statistics over the same bins.
                BM_count = psi_bins(df_bins,modelVar,interval)
                zero_rate,missing_rate,dateList,cols,y,count = dataManipul(df_observation,modelVar,np.array(interval).round(6))
                #df_observation_with_bin = pd.cut(df_observation.dropna(axis=0)[modelVar],interval)
               # del df_bins   
                del interval
                value_tab = []
                rows = []
                y_list = []
                psi = []
                # plot line separated by mon 
                for mon in dateList:
                    y_list.append(y.loc[mon].values)
                    #value_tab.append(y.loc[mon].astype('str')+'%')
                    value_tab.append(count.loc[mon].astype('str')+'(zeroR:'+zero_rate.loc[mon].astype('str')+'%)')
                    #rows.append(str(mon)+' Value');
                    rows.append(str(mon)+' Count')
                    #(y-10).sum() / np.log10(y/10)
                    # PSI = sum((actual% - benchmark%) * log10(actual% / benchmark%)) / 100.
                    # NOTE(review): PSI is conventionally defined with the
                    # natural logarithm; log10 gives smaller values - confirm
                    # that the thresholds in use account for this.
                    psi.append((((y.loc[mon]-BM_count) * np.log10(y.loc[mon]/BM_count)).sum()/100).round(3))
                # NOTE(review): plotPSI reads missing_rate by position, which
                # assumes it has the same period order as dateList.
                plotPSI(modelType[modelList.index(modelVar)]+'-'+appliedType_type[appliedType]+'-' + channel[channelID] + ' PSI',y_list,dateList,psi,missing_rate,rows,cols,value_tab,path)

            except Exception as e:
                print('psi exception',e)
            # ------------------------ lift / AUC chart -------------------------
            try:
                # Overdue dataframe
                # Relies on df_bins assigned in the PSI step above; if that
                # step failed before the assignment, the lookup raises here and
                # is reported by the handler below.
                df_bins_auc = df_bins[df_bins.transacted == 1]
                del df_bins
                # Benchmark AUC of the score against the overdue flag.
                auc_BM = sklearn.metrics.roc_auc_score(df_bins_auc.overdue, df_bins_auc.loc[:,modelVar])
                print('AUC_BM: ',auc_BM)
                # Decile edges of the transacted benchmark scores; the lowest
                # edge is forced to 0.
                Nothing,interval = pd.qcut(df_bins_auc.loc[:,modelVar],10,retbins=True,precision=6,duplicates='drop')
                interval[0] = 0
                
                del Nothing               
                df_passdueday = query_sql(sql_passdueday.replace('@modelVar',modelVar).replace('@appliedType',appliedType).replace('@channelID',channelID).replace('@passdue_day',str(passdueday)))
                # liftchart adds a 'bins' column to df_passdueday in place.
                count,df_overdue,y,dateList,cols = liftchart(df_passdueday,modelVar,np.array(interval).round(6))
                
                value_tab = []
                rows = []
                y_list = []
                aucri = []
                auc = []
                for mon in dateList:
                    y_list.append(y.loc[mon].values)
                    #value_tab.append(y.loc[mon].astype('str')+'%')
                    value_tab.append(df_overdue.loc[mon].astype('str') + ' (总计 ' + count.loc[mon].astype('str') + ')' )
                    #rows.append(str(mon)+' OverdueRate');
                    rows.append(str(mon)+' Count')
                    # Because of the added 'bins' column this also drops rows
                    # whose score fell outside every bin.
                    df_passdueday = df_passdueday.dropna(axis=0)
                    # AUCRI: AUC of this window relative to the benchmark AUC.
                    aucri.append(round((sklearn.metrics.roc_auc_score(df_passdueday[df_passdueday.applied_at==mon].overdue, df_passdueday[df_passdueday.applied_at==mon].loc[:,modelVar])/auc_BM),3))
                    auc.append(round(sklearn.metrics.roc_auc_score(df_passdueday[df_passdueday.applied_at==mon].overdue, df_passdueday[df_passdueday.applied_at==mon].loc[:,modelVar]),3))
                # Append the benchmark AUC to the last legend entry so that it
                # is shown once in the legend.
                auc[-1] = str(auc[-1]) + '\n      AUC基准: ' + str(round(auc_BM,3)) 
                plotLiftChart(modelType[modelList.index(modelVar)] + '-' + appliedType_type[appliedType] + '-' + channel[channelID] + ' AUC WITH '+ str(passdueday) + '+',y_list,dateList,aucri,auc,rows,cols,value_tab,path)

            except Exception as e:  # ZeroDivisionError
                print('val exception',e)

    
def plot_table_df(dataset, auc, title='untitled', X_label=None, y_label=None,
               tab_df=None, plot_tab=True, saved_path=None):
    '''
    instructions : visualization of pivot with single dataframe
    Params :
        dataset - DataFrame; every column is drawn as one line over the index
        auc - auc list / array, one value per column of dataset
        title - title of plot('untitled' as default)
        X_label - X axis label of plot
        y_label - y axis label of plot
        tab_df - optional DataFrame shown in the table instead of dataset
        plot_tab - plot table or not , default as True
        saved_path - prefix concatenated in front of title + ".png";
                     set as None when the figure should not be saved
    Returns :
        1 once the figure has been shown
    '''
    fig, axs = plt.subplots(1, 1, figsize=(13, 9), linewidth=0.1)
    try:
        table_rows = dataset.columns
        # Compact column labels: no blanks and no leading zero ('0.5' -> '.5').
        table_cols = pd.Series(dataset.index).astype(str).map(lambda x : x.replace(' ','')).map(lambda x : x.replace('0.','.'))

        # traverse each columns of dataframe
        x = range(len(table_cols))
        for i, column in enumerate(table_rows):
            axs.plot(x, dataset.iloc[:, i], label = str(column) + ' AUC: ' + str(auc[i]))

        if plot_tab:
            # `is None` instead of `== None`: comparing a DataFrame with ==
            # is element-wise and cannot be used as a condition.
            if tab_df is None:
                table_source = dataset
            else:
                table_source = tab_df
                table_rows = tab_df.columns
                table_cols = tab_df.index
            # One table row per column of the source frame (the previous code
            # repeated column 1 in every row).
            cell_text = [list(table_source.iloc[:, i].values) for i in range(len(table_rows))]
            # Guard against a single-column table, which used to divide by zero.
            col_width = 0.91 / max(len(table_cols) - 1, 1)
            the_table = plt.table(cellText = cell_text,
                                  rowLabels = table_rows,
                                  colLabels = table_cols,
                                  colWidths = [col_width] * len(table_cols),
                                  loc='bottom')
            the_table.auto_set_font_size(False)
            the_table.set_fontsize(9)
            # Leave room below the axes for the table.
            fig.subplots_adjust(bottom=0.2)

        plt.xticks([])
        plt.grid()
        if y_label is not None:
            plt.ylabel(y_label)
        if X_label is not None:
            plt.xlabel(X_label)
        plt.legend()
        plt.title(title)
        if saved_path is not None:
            plt.savefig(saved_path + title + ".png")
        plt.show()
    finally:
        # Release the figure so that repeated calls do not accumulate memory.
        plt.close(fig)
    return 1
    
    
    
    
    
