#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 15 11:18:31 2021

@author: gyin

This script applies CDF-based (CDF/CDF-t) quantile-mapping bias correction to ML forecast results.

09/15/2021

"""
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from statsmodels.distributions.empirical_distribution import ECDF
import netCDF4 as nc
from netCDF4 import Dataset
import time

# =============================================================================
# def quantile_correction(obs_data, mod_data, sce_data, modified=True):
#     cdf = ECDF(mod_data)  #cdf of historical ML results
#     p = cdf(sce_data) * 100 #prob of ML in historical ML data
#     
#     # find the percentage value in historical obs and ML and calculate the difference (this is the values to be adjusted)
#     cor = np.subtract(*[np.nanpercentile(x, p) for x in [obs_data, mod_data]]) 
#     
#     if modified:
#         mid = np.subtract(*[np.nanpercentile(x, 50) for x in [obs_data, mod_data]])
#         g = np.true_divide(*[np.nanpercentile(x, 50) for x in [obs_data, mod_data]])
# 
#         iqr_obs_data = np.subtract(*np.nanpercentile(obs_data, [75, 25]))
#         iqr_mod_data = np.subtract(*np.nanpercentile(mod_data, [75, 25]))
# 
#         f = np.true_divide(iqr_obs_data, iqr_mod_data)
#         cor = g * mid + f * (cor - mid)
#         return sce_data + cor
#     else:
#         return sce_data + cor
#     
# =============================================================================
    
    
# This version excludes zero (dry) values when building the distributions.
    
def quantile_correction(obs_data, mod_data, sce_data, threshold=0.1):
    """Bias-correct the wet values of ``sce_data`` by empirical quantile mapping.

    Values below ``threshold`` are treated as zero (dry) in all three inputs,
    and only wet (> 0) values are used to build the distributions.  For each
    wet value ``v`` of ``sce_data``:

    1. ``p`` = percentage of wet ``mod_data`` values that are <= ``v``
       (the right-continuous empirical CDF, the same quantity statsmodels'
       ``ECDF`` returns);
    2. ``v`` is shifted by ``percentile(wet obs, p) - percentile(wet mod, p)``.

    Parameters
    ----------
    obs_data : array_like
        Historical observations; NaNs are ignored.
    mod_data : array_like
        Historical model (ML) output; NaNs are ignored.
    sce_data : array_like
        Values to correct.  It is NOT modified.
    threshold : float, optional
        Values below this are set to 0 (dry).  Defaults to 0.1.

    Returns
    -------
    numpy.ndarray
        New float array shaped like ``sce_data``: NaNs kept, values below
        ``threshold`` set to 0, wet values corrected.  If ``sce_data`` has no
        wet value, or the observations or the model have no wet value to build
        a distribution from, the thresholded values are returned uncorrected.
    """
    # Private float copy, so the caller's array is never modified.
    sce = np.array(sce_data, dtype=float)

    def _wet_values(data):
        """Return the non-NaN values of ``data`` that remain > 0 after thresholding."""
        values = np.asarray(data, dtype=float)
        values = values[~np.isnan(values)]
        values = np.where(values < threshold, 0, values)
        return values[values > 0]

    wet_obs = _wet_values(obs_data)
    wet_mod = np.sort(_wet_values(mod_data))

    # Zero out dry forecast values; NaNs stay NaN.
    valid = ~np.isnan(sce)
    sce[valid] = np.where(sce[valid] < threshold, 0, sce[valid])
    wet = sce > 0

    # Nothing to correct, or no wet reference distribution to map through.
    if not wet.any() or wet_obs.size == 0 or wet_mod.size == 0:
        return sce

    # Probability (in %) of each wet forecast value within the wet model values.
    p = np.searchsorted(wet_mod, sce[wet], side='right') / wet_mod.size * 100

    # Difference between the observed and model values at the same percentile.
    cor = np.percentile(wet_obs, p) - np.percentile(wet_mod, p)

    # NOTE(review): cor can be negative, so a corrected value may drop below
    # zero; no clipping is applied here.
    sce[wet] = sce[wet] + cor
    return sce
    


start_time = time.time()

# define study period
allyears = list(range(2007, 2020))  # years whose cross-validation results build the distributions
testyears = [2020]                  # years whose forecasts are corrected
month = 1


nyears = len(allyears)  # derived, so it can never disagree with allyears
ntime = 744             # time steps read from each cross-validation file (31 days x 24 h)
fcst_time = 39          # time steps per forecast / output file

start_day = 1
end_day = 31
hours = [0, 3, 6, 9, 12, 15, 18, 21]  # HH of the forecast files processed each day


# define feature vector size (nums = (2*rg+1)^2)
rg = 12
nums = (2 * rg + 1) ** 2  # derived from rg instead of a separate literal (625)

# define hyperparameters (already tuned values); not referenced by the CDF step
gm = 0.000005
c2 = 4
ep = 0.001

# Grid definition of each site:
#   nlon, nlat                  -- grid size
#   north, south, west, east    -- domain bounds
#   xnum1..xnum2 / ynum1..ynum2 -- 1-based, inclusive lon / lat index range
#                                  of the region that is corrected
SITE_GRIDS = {
    'siteA': dict(nlon=160, nlat=120,
                  north=46.38, south=39.2, west=138.0, east=147.58,
                  xnum1=30, xnum2=70, ynum1=35, ynum2=80),
    'siteB': dict(nlon=150, nlat=140,
                  north=40.95, south=32.6, west=134.4, east=143.38,
                  xnum1=40, xnum2=80, ynum1=60, ynum2=105),
    'siteC': dict(nlon=140, nlat=140,
                  north=37.35, south=29, west=127.8, east=136.15,
                  xnum1=13, xnum2=128, ynum1=13, ynum2=128),
    'siteD': dict(nlon=110, nlat=112,
                  north=30.9, south=24.2, west=126.0, east=132.55,
                  xnum1=50, xnum2=85, ynum1=35, ynum2=75),
    'siteE': dict(nlon=100, nlat=80,
                  north=27.75, south=23.0, west=121.8, east=127.75,
                  xnum1=40, xnum2=70, ynum1=30, ynum2=65),
}

# define site
site = 'siteC'

if site not in SITE_GRIDS:
    raise ValueError('site setting is wrong, error!!!')

_grid = SITE_GRIDS[site]
nlon, nlat = _grid['nlon'], _grid['nlat']
north, south = _grid['north'], _grid['south']
west, east = _grid['west'], _grid['east']
xnum1, xnum2 = _grid['xnum1'], _grid['xnum2']  # lon
ynum1, ynum2 = _grid['ynum1'], _grid['ynum2']  # lat


# define data path
# cross_path = '/Users/gyin/Documents/Research/ML/p1/cross07/save_nc'    # path of cross validation results
# ML_path = '/Users/gyin/Documents/Research/ML/p1/fcst_39hr/save_nc/cdf_cdft'   # path of 39 hour forecast SVM results 
# save_path = '/Users/gyin/Documents/Research/ML/p1/fcst_39hr/save_nc/test_python_cdf'  # path to save 39 hour forecast ML-CDF results

cross_path = './save_nc_cross'    # path of cross validation results
ML_path = './save_nc_39hr_fcst'   # path of 39 hour forecast SVM results
save_path = './save_nc_39hr_cdf'  # path to save 39 hour forecast ML-CDF results




# =============================================================================
# # Load cross-validation results for building distribution function
# =============================================================================
# obs_mat / svm_mat hold the observed and SVM cross-validation values of all
# years in allyears, stacked along the third (time) axis in that order:
# shape (nlon, nlat, len(allyears) * ntime) when every file has ntime steps.

obs_by_year = []
svm_by_year = []

for year in allyears:

    print("loading year="+str(year))

    f1 = cross_path+"/ML_hourly_"+str(year)+str(month).zfill(2)+".nc"
    #f1 = cross_path+"/hourly_"+str(year)+str(month).zfill(2)+".nc"

    # Open each file once; the context manager closes it after reading.
    # np.ma.filled(..., nan) turns masked (fill-value) entries into NaN so they
    # are ignored downstream instead of leaking in as raw fill values.
    with Dataset(f1, 'r') as cv_file:
        obs_by_year.append(np.ma.filled(
            np.ma.asarray(cv_file.variables['obs'][0:nlon, 0:nlat, 0:ntime], dtype=float),
            np.nan))
        svm_by_year.append(np.ma.filled(
            np.ma.asarray(cv_file.variables['svm'][0:nlon, 0:nlat, 0:ntime], dtype=float),
            np.nan))

# Concatenate once: growing the arrays with np.dstack inside the loop would
# copy all previously loaded years on every iteration.
obs_mat = np.concatenate(obs_by_year, axis=2)
svm_mat = np.concatenate(svm_by_year, axis=2)
del obs_by_year, svm_by_year







# =============================================================================
#  # ===========================  start conducting CDF ========================
# =============================================================================

# Grid cells with enough usable training data.  This depends only on obs_mat
# and svm_mat, so it is evaluated once per cell here rather than once per cell
# per forecast hour.  A cell is skipped when obs or svm has fewer than 10
# non-NaN values or sums to zero.
enough_data = np.zeros((nlon, nlat), dtype=bool)
for ix in range(xnum1 - 1, xnum2):
    for iy in range(ynum1 - 1, ynum2):
        obs_data = obs_mat[ix, iy, :]
        svm_data = svm_mat[ix, iy, :]
        enough_data[ix, iy] = (np.count_nonzero(~np.isnan(obs_data)) >= 10
                               and np.count_nonzero(~np.isnan(svm_data)) >= 10
                               and np.nansum(obs_data) != 0
                               and np.nansum(svm_data) != 0)
del obs_data, svm_data

# Coordinate axes (0.06 spacing) written to every output file.  Built from the
# grid size so they always have exactly nlat / nlon points; np.arange with a
# float step can gain or lose a point through rounding.
lat_values = south + 0.06 * np.arange(nlat)
lon_values = west + 0.06 * np.arange(nlon)

for current_year in testyears:

    for current_day in range(start_day, end_day + 1):

        for current_hr in hours:

            print('processing '+str(current_year)+'-'+str(month).zfill(2)+'-'+str(current_day).zfill(2)+' '+str(current_hr).zfill(2)+':00')

            current_ymdh = str(current_year)+str(month).zfill(2)+str(current_day).zfill(2)+str(current_hr).zfill(2)

            # load 39hour forecast SVM results: read the fcst_time steps that
            # the output variables hold, then close the file
            f2 = ML_path + "/ML_39hr_fcst_"+current_ymdh+".nc"
            with Dataset(f2, 'r') as fcst_file:
                current_svm = fcst_file.variables['svm'][0:nlon, 0:nlat, 0:fcst_time]
                current_msm = fcst_file.variables['msm'][0:nlon, 0:nlat, 0:fcst_time]

            # corrected forecast per grid cell; cells that are not processed stay NaN
            mlcdf_mat = np.full((nlon, nlat, fcst_time), np.nan)

            # xnum*/ynum* are 1-based, inclusive grid indices
            for ixx in range(xnum1, xnum2 + 1):
                for iyy in range(ynum1, ynum2 + 1):

                    ix = ixx - 1
                    iy = iyy - 1

                    # too few non-NaN (or all-zero) training data: skip the grid
                    if not enough_data[ix, iy]:
                        continue

                    # Private float copy of this cell's forecast: quantile_correction
                    # may write into its third argument, and current_svm is saved
                    # unchanged below.  Masked (fill-value) entries become NaN.
                    valid_svm = np.array(
                        np.ma.filled(np.ma.asarray(current_svm[ix, iy, 0:fcst_time], dtype=float), np.nan),
                        dtype=float,
                    )

                    mlcdf_mat[ix, iy, 0:fcst_time] = quantile_correction(
                        obs_mat[ix, iy, :], svm_mat[ix, iy, :], valid_svm)

            fn = save_path+"/MLCDF_39hr_fcst_"+current_ymdh+".nc"
            # the context manager closes the file even if a write fails
            with nc.Dataset(fn, 'w', format='NETCDF4') as ds:

                ds.createDimension('nlat', nlat)
                ds.createDimension('nlon', nlon)
                ds.createDimension('ntime', fcst_time)

                lat1D = ds.createVariable('lat1D', 'f', ('nlat',))
                lon1D = ds.createVariable('lon1D', 'f', ('nlon',))
                svm = ds.createVariable('svm', 'f', ('nlon', 'nlat', 'ntime'))
                msm = ds.createVariable('msm', 'f', ('nlon', 'nlat', 'ntime'))
                svmcdf = ds.createVariable('svmcdf', 'f', ('nlon', 'nlat', 'ntime'))

                lat1D[:] = lat_values
                lon1D[:] = lon_values
                svm[:] = current_svm
                msm[:] = current_msm
                svmcdf[:] = mlcdf_mat

                msm.units = 'mm/hr'
                msm.long_name = 'MSM-GPV'

            del current_svm, current_msm, mlcdf_mat



print("--- %s seconds ---" % (time.time() - start_time))










