from pyspark.sql import DataFrame
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.sql.window import Window
import pandas as pd
import random
import logging
from time import sleep
from rediscluster import RedisCluster
from datetime import datetime
from pyspark.sql import SparkSession
# Spark session for this job: runs on YARN with Hive support and dynamic
# allocation (8 to 15 executors, 4 cores and 8g memory each).
# NOTE(review): "some-value" for spark.sql.warehouse.dir looks like a
# placeholder - confirm the intended warehouse location.
spark = (
    SparkSession.builder
    .master("yarn")
    .appName("write_product_similarity_redis.py")
    .config("spark.sql.warehouse.dir", "some-value")
    .config('spark.executor.memory', '8g')
    .config("spark.dynamicAllocation.enabled", 'true')
    .config("spark.dynamicAllocation.shuffleTracking.enabled", 'true')
    .config("spark.shuffle.service.enabled", 'true')
    .config("spark.dynamicAllocation.maxExecutors", 15)
    .config("spark.dynamicAllocation.minExecutors", 8)
    .config("spark.executor.cores", "4")
    .enableHiveSupport()
    .getOrCreate()
)

# Session-level SQL options applied after the session is created.
spark.conf.set("spark.sql.crossJoin.enabled", "true")
# NOTE(review): on Spark 3.x this key is a deprecated alias of
# spark.sql.execution.arrow.pyspark.enabled - confirm the Spark version.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic")
spark.conf.set("hive.exec.dynamic.partition.mode", "nonstrict")

from sqlalchemy import create_engine

# SQLAlchemy engine (PyMySQL driver) for the online recommender MySQL database.
# NOTE(review): the username and password are hard-coded in this URL; move them
# to an environment variable or secret store and rotate the password.
online_xyqb_recommender_system = create_engine(
    'mysql+pymysql://xyqb_recommender:vxVFCgWTKjYb0xfR@rm-2ze1l8mi94dkd255c.mysql.rds.aliyuncs.com:3306/xyqb_recommender_system?charset=utf8',
    echo=False,
)

import warnings
# Suppress all Python warnings for the remainder of the script.
warnings.filterwarnings('ignore')


# Record and announce when writing the similarity results starts.
write_db_start_time = datetime.now()
print("相似度表结果开始入库，开始时间：{}".format(write_db_start_time))
print('')

def conredis_test(config):
    """Create a RedisCluster client for the test environment.

    ``config`` must contain ``REDIS_NODES`` (startup-node list, e.g.
    ``[{'host': ..., 'port': 6379}]``) and ``REDIS_MAX_CONNECTIONS``.
    ``REDIS_PASSWD`` is optional and is passed to ``RedisCluster`` as
    ``password``. Other keys (such as ``REDIS_EXPIRETIME``) are ignored here.

    Returns the client, or None if it could not be created. Such a failure is
    logged with its traceback rather than raised. A missing required key
    raises KeyError.
    """
    redis_nodes = config["REDIS_NODES"]
    max_connections = config["REDIS_MAX_CONNECTIONS"]
    redis_password = config.get("REDIS_PASSWD")
    try:
        return RedisCluster(
            startup_nodes=redis_nodes,
            max_connections=max_connections,
            password=redis_password,
        )
    except Exception:
        # Best-effort contract: callers receive None, but the cause is recorded
        # instead of being silently discarded.
        logging.getLogger(__name__).exception(
            "redis cluster connect error, redis_nodes: %s", redis_nodes)
        return None


def conredis_official(config):
    """Create a RedisCluster client for the production (official) cluster.

    ``config`` must contain ``REDIS_NODES`` (startup-node list, e.g.
    ``[{'host': ..., 'port': 6379}]``) and ``REDIS_MAX_CONNECTIONS``.
    ``REDIS_PASSWD`` is optional; it was previously read but never used, and
    is now passed to ``RedisCluster`` as ``password``. None, the default when
    the key is absent, means no AUTH. Other keys (such as
    ``REDIS_EXPIRETIME``) are ignored here.

    ``skip_full_coverage_check=True`` skips the cluster-require-full-coverage
    CONFIG check. Presumably this is because the managed Redis service does
    not expose CONFIG; confirm this.

    Returns the client, or None if it could not be created. Such a failure is
    logged with its traceback rather than raised. A missing required key
    raises KeyError.
    """
    redis_nodes = config["REDIS_NODES"]
    max_connections = config["REDIS_MAX_CONNECTIONS"]
    redis_password = config.get("REDIS_PASSWD", None)
    try:
        return RedisCluster(
            startup_nodes=redis_nodes,
            max_connections=max_connections,
            password=redis_password,
            skip_full_coverage_check=True,
        )
    except Exception:
        # Best-effort contract: callers receive None, but the cause is recorded
        # instead of being silently discarded.
        logging.getLogger(__name__).exception(
            "redis cluster connect error, redis_nodes: %s", redis_nodes)
        return None


def similarity_product_2_redis(df: DataFrame, num_partitions=2, batch_size=50000, sleep_secs=0.2):
    """
    Write per-product similarity lists from a Spark DataFrame to Redis in parallel batches.

    For every row, the values of its ``product_id2_list`` column are stored, in
    order, as a Redis list under the key ``product_similarity:<product_id1>``.
    Any list already stored under that key is replaced. The key is then given a
    TTL of 7 days. Each Spark partition opens its own Redis connection and
    writes through a non-transactional pipeline.

    ``df`` is the Spark DataFrame to write. It needs a ``product_id1`` column;
        rows are only written when it also has a ``product_id2_list`` column.
    ``num_partitions`` (int) is the number of partitions ``df`` is repartitioned
        into before writing. Each partition opens one Redis connection.
        Default 2.
    ``batch_size`` (int) is the number of rows queued in the pipeline before it
        is executed. Default 50000.
    ``sleep_secs`` (float) is the number of seconds a worker sleeps after
        executing each full batch, to throttle the write rate. Default 0.2.

    A partition that cannot connect to Redis raises RuntimeError, which fails
    the Spark job.
    """
    # Resolved once on the driver; only this flag is shipped to the executors.
    has_list_col = 'product_id2_list' in df.columns

    def _save_2_redis(rows):
        """Write one partition's rows to Redis (runs on a Spark executor)."""
        config_official_ali = {
            'REDIS_NODES': [
                {'host': 'r-2ze3ferg4oc0tuqeou.redis.rds.aliyuncs.com', 'port': 6379},
            ],
            'REDIS_EXPIRETIME': 24 * 7 * 3600,  # key TTL in seconds: 7 days
            'REDIS_MAX_CONNECTIONS': 50,
            #     'REDIS_PASSWD': 'redis',
        }
        client = conredis_official(config_official_ali)
        if client is None:
            # conredis_official returns None when it cannot connect. Fail loudly
            # here instead of with an opaque AttributeError on .pipeline().
            raise RuntimeError(
                'could not connect to Redis cluster {}'.format(config_official_ali['REDIS_NODES']))
        pipe = client.pipeline(transaction=False)
        ttl = config_official_ali['REDIS_EXPIRETIME']

        pending = 0  # rows handled since the last pipeline execute
        for row in rows:
            redis_key = 'product_similarity:' + str(row.product_id1)
            values = row.product_id2_list if has_list_col else None
            if values is not None:
                values = list(values)
                # Use DEL instead of LTRIM key -1 0. The LTRIM form leaves a
                # single-element list intact, so stale data would survive.
                pipe.delete(redis_key)
                if values:  # RPUSH with no values is a Redis error
                    pipe.rpush(redis_key, *values)
                pipe.expire(redis_key, ttl)
            # Count rows with a dedicated counter. The old code reused `i` for
            # the inner value loop, so batching keyed off a product id.
            pending += 1
            if pending == batch_size:
                pipe.execute()
                pending = 0
                # Sleep after each full batch to throttle the write rate.
                sleep(sleep_secs)
        if pending:
            # Flush the final batch, which is smaller than batch_size.
            pipe.execute()

    # Distributed batch write: one _save_2_redis call per partition.
    df.repartition(num_partitions).foreachPartition(_save_2_redis)


# ---------------------------------------------------------------------------
# 01 Fetch the product-similarity table and push it to Redis
# ---------------------------------------------------------------------------
sql = '''
select product_id1, product_id2, similarity_rank
from mix_products_similarity
where similarity_rank <= 100
-- and valid_code=1
-- limit 100
'''
mix_prds_sim = pd.read_sql(sql, con=online_xyqb_recommender_system)
# NOTE: this counts rows returned by the query, i.e. (product_id1, product_id2)
# pairs, not distinct products.
print('商品数：', mix_prds_sim.shape[0])
mix_prds_sim_df = spark.createDataFrame(mix_prds_sim)

# One row per product_id1, holding its product_id2 neighbours ordered by
# ascending similarity_rank. sort_array compares the structs field by field,
# with similarity_rank first.
grouped = (
    mix_prds_sim_df
    .groupBy("product_id1")
    .agg(collect_list(struct("similarity_rank", "product_id2")).alias("tmp"))
)
grouped2 = grouped.select(
    "product_id1",
    sort_array("tmp")["product_id2"].alias("product_id2_list"),
)
similarity_product_2_redis(grouped2, num_partitions=200)

# 近线层redis写相似表 - Ali