#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Sep 28 12:40:18 2021

@author: gyin

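This script performs step 3 of the 39-hour forecast workflow: it applies the
previously trained per-grid-point models to the MSM-GPV 39-hour forecasts and
saves the predictions (with the collocated MSM-GPV values) as NetCDF files.

    -- step 3: run 39-hr fcst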
    

"""


import pandas as pd
import os
import numpy as np
import sys
import matplotlib.pyplot as plt
from imblearn.under_sampling import RandomUnderSampler

from sklearn.svm import SVR
from sklearn import metrics
import joblib
import netCDF4 as nc
import time



# =============================================================================
# import cartopy.crs as ccrs
# import matplotlib.cm as cm
# import matplotlib as mpl
# mpl.rcParams['font.family'] = 'Arial'
# mpl.rcParams['font.size']=18
# from matplotlib.colors import LinearSegmentedColormap
# 
# def generate_cmap(colors):
#     """Return a custom colormap built from the given list of colors."""
#     values = range(len(colors))
# 
#     vmax = np.ceil(np.max(values))
#     color_list = []
#     for v, c in zip(values, colors):
#         color_list.append( ( v/ vmax, c) )
#     return LinearSegmentedColormap.from_list('custom_cmap', color_list)
# =============================================================================

#from sklearn.datasets import load_svmlight_file

start_time = time.time()

# ---- study period ----
# Training years; not referenced below (the models are loaded pre-trained).
allyears = list(range(2007, 2020))  # 2007 ... 2019
nyears = len(allyears)
# Period to forecast: every listed hour of every day in [start_day, end_day]
# of `month` for each test year.
testyears = [2020]
month = 1
start_day = 1
end_day = 31
hours = [0, 3, 6, 9, 12, 15, 18, 21]

# Number of records read from each forecast file (39-hour forecast).
fcst_time = 39

# ---- feature vector ----
# Each point's features are all values in the (2*rg+1) x (2*rg+1) window
# centred on it, so the vector length follows directly from rg.
rg = 12
nums = (2 * rg + 1) ** 2

# ---- hyperparameters (already tuned values; not referenced below) ----
gm = 0.000005
c2 = 4
ep = 0.001

# ---- site ----
site = 'siteC'

# Grid size, domain edges and the tested region for each site.
# xnum*/ynum* are 1-based grid numbers along lon/lat respectively.
_SITES = {
    'siteA': dict(nlon=160, nlat=120,
                  north=46.38, south=39.2, west=138.0, east=147.58,
                  xnum1=30, xnum2=70, ynum1=35, ynum2=80),
    'siteB': dict(nlon=150, nlat=140,
                  north=40.95, south=32.6, west=134.4, east=143.38,
                  xnum1=40, xnum2=80, ynum1=60, ynum2=105),
    'siteC': dict(nlon=140, nlat=140,
                  north=37.35, south=29, west=127.8, east=136.15,
                  xnum1=13, xnum2=128, ynum1=13, ynum2=128),
    'siteD': dict(nlon=110, nlat=112,
                  north=30.9, south=24.2, west=126.0, east=132.55,
                  xnum1=50, xnum2=85, ynum1=35, ynum2=75),
    'siteE': dict(nlon=100, nlat=80,
                  north=27.75, south=23.0, west=121.8, east=127.75,
                  xnum1=40, xnum2=70, ynum1=30, ynum2=65),
}

try:
    _cfg = _SITES[site]
except KeyError:
    raise ValueError('site setting is wrong, error!!!') from None

nlon, nlat = _cfg['nlon'], _cfg['nlat']
north, south = _cfg['north'], _cfg['south']
west, east = _cfg['west'], _cfg['east']
xnum1, xnum2 = _cfg['xnum1'], _cfg['xnum2']  # tested region, lon
ynum1, ynum2 = _cfg['ynum1'], _cfg['ynum2']  # tested region, lat

# Every tested point needs its full feature window inside the grid; otherwise
# the window slice is truncated and the feature reshape fails much later.
if (xnum1 - 1 - rg < 0 or xnum2 + rg > nlon
        or ynum1 - 1 - rg < 0 or ynum2 + rg > nlat):
    raise ValueError('tested region of %s plus a %d-point margin exceeds '
                     'the %dx%d grid' % (site, rg, nlon, nlat))

# ---- data paths ----
#MSM_path = '/Users/gyin/Documents/Research/ML/python_version/fcst_39hr/MSMGPV_39hr'
MSM_path = '/data37/gyin/run_fcst/MSMGPV_39hr'
classifier_path = './data-model-under'
save_path = './save_nc_39hr_fcst'


def _read_msm_fcst(file_path, nlon, nlat, fcst_time, missing=-9.99e8):
    """Read one MSM-GPV forecast file into a float64 (nlon, nlat, fcst_time) array.

    The file is read as ``fcst_time`` consecutive records of ``nlon*nlat``
    little-endian 4-byte floats; within a record the longitude index varies
    fastest (each record is a Fortran-ordered (nlon, nlat) field).  Values
    equal to ``missing`` are replaced with NaN.

    Raises:
        ValueError: if the file holds fewer than ``fcst_time`` records.
    """
    npts = nlon * nlat
    # Passing the path lets numpy open the file in binary mode and close it.
    raw = np.fromfile(file_path, dtype='<f4', count=npts * fcst_time)
    if raw.size < npts * fcst_time:
        raise ValueError('%s: expected %d records of %d values, got %d values'
                         % (file_path, fcst_time, npts, raw.size))
    # Value (i, j) of record n sits at raw[n*npts + j*nlon + i]; the C-order
    # (time, lat, lon) view transposed to (lon, lat, time) addresses exactly that.
    msm = raw.reshape((fcst_time, nlat, nlon)).transpose(2, 1, 0).astype(np.float64)
    msm[msm == missing] = np.nan
    return msm


def _write_netcdf(fn, svm_data, msm_data, south, west, step=0.06):
    """Write ML predictions and collocated MSM-GPV values to a NetCDF4 file.

    ``svm_data`` and ``msm_data`` are (nlon, nlat, ntime) arrays.  The 1-D
    lat/lon axes start at ``south``/``west`` with a regular spacing of
    ``step`` (same units as south/west), generated by count so their lengths
    always match the grid dimensions.
    """
    n_lon, n_lat, n_time = svm_data.shape
    # The context manager closes the file even if a write fails.
    with nc.Dataset(fn, 'w', format='NETCDF4') as ds:
        ds.createDimension('nlat', n_lat)
        ds.createDimension('nlon', n_lon)
        ds.createDimension('ntime', n_time)

        lat1D = ds.createVariable('lat1D', 'f', ('nlat',))
        lon1D = ds.createVariable('lon1D', 'f', ('nlon',))
        svm = ds.createVariable('svm', 'f', ('nlon', 'nlat', 'ntime'))
        msm = ds.createVariable('msm', 'f', ('nlon', 'nlat', 'ntime'))

        lat1D[:] = south + step * np.arange(n_lat)
        lon1D[:] = west + step * np.arange(n_lon)
        svm[:] = svm_data
        msm[:] = msm_data
        msm.unit = 'mm/hr'
        msm.units = 'mm/hr'  # CF-standard name; 'unit' kept for existing readers
        msm.long_name = 'MSM-GPV'


os.makedirs(save_path, exist_ok=True)

for current_year in testyears:
    for current_day in range(start_day, end_day + 1):
        for current_hr in hours:
            print('processing %d-%02d-%02d %02d:00'
                  % (current_year, month, current_day, current_hr))

            current_ym = str(current_year) + str(month).zfill(2)
            current_ymdh = (current_ym + str(current_day).zfill(2)
                            + str(current_hr).zfill(2))

            # ---- load the MSM-GPV 39-hour forecast ----
            file_path = os.path.join(MSM_path, site, current_ym, current_ymdh,
                                     'rsms-' + current_ymdh + '.grd')
            msm_mat = _read_msm_fcst(file_path, nlon, nlat, fcst_time)

            # Outputs stay NaN wherever there is no model or no valid input.
            ML_mat = np.full((nlon, nlat, fcst_time), np.nan)
            target_msm_mat = np.full((nlon, nlat, fcst_time), np.nan)

            # ---- apply each grid point's model ----
            # NOTE(review): every model is re-read from disk for each forecast
            # file; caching them would save I/O but may need a lot of memory.
            for ixx in range(xnum1, xnum2 + 1):
                ix = ixx - 1  # 1-based grid number -> 0-based array index
                for iyy in range(ynum1, ynum2 + 1):
                    iy = iyy - 1

                    classifier_name = (classifier_path + '/mod1_'
                                       + str(ixx).zfill(3) + '_'
                                       + str(iyy).zfill(3) + '.cmp')
                    if not os.path.exists(classifier_name):
                        continue

                    # Features: the (2*rg+1) x (2*rg+1) window around the
                    # point, one row per forecast record -> (fcst_time, nums).
                    window = msm_mat[ix - rg:ix + rg + 1, iy - rg:iy + rg + 1, :]
                    test_x = window.reshape(nums, fcst_time).T

                    # Only predict for records whose whole window is valid;
                    # skip loading the model when there is nothing to predict.
                    valid = ~np.isnan(test_x).any(axis=1)
                    if not valid.any():
                        continue

                    clf = joblib.load(classifier_name)
                    pred = clf.predict(test_x[valid, :])
                    pred[pred < 0] = 0.0  # let negative predictions be zero

                    ML_mat[ix, iy, valid] = pred
                    target_msm_mat[ix, iy, :] = msm_mat[ix, iy, :]

            # ---- save predictions for this forecast file as NetCDF ----
            _write_netcdf(os.path.join(save_path,
                                       'ML_39hr_fcst_' + current_ymdh + '.nc'),
                          ML_mat, target_msm_mat, south, west)


print("--- %s seconds ---" % (time.time() - start_time))








