#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Fri Sep 10 12:12:12 2021

@author: gyin

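Train per-grid-point SVR models that predict Radar-AMeDAS precipitation
from a neighbourhood of MSM-GPV values, and write the predictions to NetCDF.

- Step 1: leave-one-year-out cross-validation (2007-2019)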

"""

import pandas as pd
import os
import numpy as np
import sys
import matplotlib.pyplot as plt
from imblearn.under_sampling import RandomUnderSampler

from sklearn.svm import SVR
from sklearn import metrics
import joblib
import netCDF4 as nc
import time
#from sklearn.datasets import load_svmlight_file

start_time = time.time()

# ---------------------------------------------------------------------------
# Study period. Every year in `allyears` is loaded; each year in `testyears`
# is held out in turn as the test set while the remaining loaded years are
# used for training (leave-one-year-out cross-validation).
# ---------------------------------------------------------------------------
#allyears = [2007, 2008, 2009]   # small subset for quick test runs
allyears = [2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019]
testyears = [2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015, 2016, 2017, 2018, 2019]
month = 1

# A test year that is not loaded has no data slice; fail now rather than
# after the (long) data-loading phase.
_unloaded = sorted(set(testyears) - set(allyears))
if _unloaded:
    raise ValueError('testyears not contained in allyears: %s' % _unloaded)
del _unloaded

# Number of loaded years, derived so it cannot drift from `allyears`.
nyears = len(allyears)
# Records read per monthly input file (744 = 31 * 24; presumably hourly steps
# of a 31-day month -- TODO confirm before using a shorter month).
ntime = 744

# Feature window half-width: each feature vector is the (2*rg+1) x (2*rg+1)
# MSM window centred on the target grid point, so nums = (2*rg+1)^2.
rg = 12
nums = (2 * rg + 1) ** 2

# SVR hyperparameters (already tuned values): RBF gamma, penalty C, epsilon.
gm = 0.000005
c2 = 4
ep = 0.001

# ---------------------------------------------------------------------------
# Study site. Each table row gives, in order:
#   nlon, nlat                 grid size
#   north, south, west, east   domain bounds in degrees
#   xnum1, xnum2               tested lon range (1-based grid indices, inclusive)
#   ynum1, ynum2               tested lat range (1-based grid indices, inclusive)
# ---------------------------------------------------------------------------
site = 'siteC'

_SITES = {
    'siteA': (160, 120, 46.38, 39.2, 138.0, 147.58, 30, 70, 35, 80),
    'siteB': (150, 140, 40.95, 32.6, 134.4, 143.38, 40, 80, 60, 105),
    'siteC': (140, 140, 37.35, 29, 127.8, 136.15, 13, 128, 13, 128),
    'siteD': (110, 112, 30.9, 24.2, 126.0, 132.55, 50, 85, 35, 75),
    'siteE': (100, 80, 27.75, 23.0, 121.8, 127.75, 40, 70, 30, 65),
}

if site not in _SITES:
    raise Exception('site setting is wrong, error!!!')

(nlon, nlat,
 north, south, west, east,
 xnum1, xnum2,
 ynum1, ynum2) = _SITES[site]


# ---------------------------------------------------------------------------
# Data locations: MSM-GPV and Radar-AMeDAS input roots, the directory for
# dumped classifiers (dumping is currently commented out) and the NetCDF
# output directory. A local development copy used
# /Users/gyin/Documents/Research/ML/python_version/{data/MSM, data/Radar,
# data-model, save_nc} instead.
# ---------------------------------------------------------------------------
_input_root = '/data37/gyin/python_under/data_under'
MSM_path = _input_root + '/MSM_GPV_Rjp_cut'
OBS_path = _input_root + '/SRF_GPV_Ggis1km_cut'
classifier_path = './data-model-cross'
save_path = './save_nc_cross'

# ====================================================================================================
# ============================================ start loading data ====================================
# ====================================================================================================

# Flag value marking missing data in the .grd input files.
_MISSING_VALUE = -9.99e8


def _read_grd(path):
    """Read one monthly .grd file as a (nlon, nlat, ntime) array.

    The file is read as ``ntime`` consecutive records of ``nlon * nlat``
    little-endian float32 values; inside a record the first (lon) index
    varies fastest (Fortran order). Data beyond ``ntime`` records is ignored.

    Raises:
        OSError: if ``path`` cannot be opened.
        ValueError: if the file holds fewer than ``ntime`` full records.
    """
    expected = nlon * nlat * ntime
    # Passing the path lets numpy open (and close) the file in binary mode.
    raw = np.fromfile(path, dtype='<f4', count=expected)
    if raw.size != expected:
        raise ValueError('%s: expected %d float32 values (%d records), found %d'
                         % (path, expected, ntime, raw.size))
    # C-order (ntime, nlat, nlon) has the same memory layout as ntime records
    # of Fortran-order (nlon, nlat); transpose to (nlon, nlat, ntime).
    return raw.reshape(ntime, nlat, nlon).transpose(2, 1, 0)


total_time = ntime * len(allyears)

# Preallocate the stacked (nlon, nlat, time) arrays once instead of growing
# them with np.dstack every year, which re-copied all loaded data each time.
obs_mat = np.empty((nlon, nlat, total_time))
msm_mat = np.empty((nlon, nlat, total_time))

for k, year in enumerate(allyears):

    current_ym = str(year) + str(month).zfill(2)
    print('loading MSM and Radar data for year ' + str(year) + '-' + str(month).zfill(2))

    # Year allyears[k] occupies time steps [k*ntime, (k+1)*ntime).
    year_slice = slice(k * ntime, (k + 1) * ntime)

    # ==================================== load Radar-AMeDAS data ====================================
    obs_mat[:, :, year_slice] = _read_grd(
        OBS_path + '/' + site + '/' + current_ym + '/radar-' + current_ym + '.grd')

    # ==================================== load MSM-GPV data ========================================
    msm_mat[:, :, year_slice] = _read_grd(
        MSM_path + '/' + site + '/' + current_ym + '/rsms-' + current_ym + '.grd')

# Replace the missing-data flag with NaN. -9.99e8 is exactly representable in
# float32, so the equality test on the up-cast float64 values is exact.
obs_mat[obs_mat == _MISSING_VALUE] = np.nan
msm_mat[msm_mat == _MISSING_VALUE] = np.nan

# ====================================================================================================
# ================================================ start ML ==========================================
# ====================================================================================================

# Year allyears[k] occupies time indices ind_start[k]:ind_end[k] of the
# stacked arrays (data were stacked in allyears order, ntime steps per year).
ind_start = range(0, total_time, ntime)
ind_end = range(ntime, total_time + 1, ntime)

# Targets below this value count as "dry" when balancing the training set
# (same units as the radar data, mm/hr).
_WET_THRESHOLD = 0.1
# Spacing (degrees) used to build the lat/lon coordinate variables.
_GRID_STEP = 0.06


def _balance_training_set(x, y):
    """Randomly undersample the majority of dry (y < 0.1) vs wet samples.

    Parameters:
        x: 2-D feature array, one row per sample.
        y: 1-D target array aligned with the rows of ``x`` (NaN-free).

    Returns:
        (x, y) after undersampling, or the inputs unchanged when all samples
        are in one class (RandomUnderSampler needs at least two classes) or
        the two classes are already the same size.
    """
    n_samples = y.size
    n_dry = np.count_nonzero(y < _WET_THRESHOLD)
    if n_dry in (0, n_samples) or 2 * n_dry == n_samples:
        return x, y
    # NOTE(review): no random_state is passed, so the selected samples (and
    # hence the trained models) differ between runs.
    undersample = RandomUnderSampler(sampling_strategy='majority')
    resampled, _ = undersample.fit_resample(np.column_stack((x, y)),
                                            y >= _WET_THRESHOLD)
    n_features = x.shape[1]
    return resampled[:, :n_features], resampled[:, n_features]


def _write_year(year, svm_pred, obs_target, msm_target):
    """Write one test year's fields to save_path/ML_hourly_<YYYYMM>.nc.

    Parameters:
        year: test year, used in the output file name.
        svm_pred, obs_target, msm_target: (nlon, nlat, ntime) arrays written
            to the 'svm', 'obs' and 'msm' variables respectively.
    """
    fn = save_path + "/ML_hourly_" + str(year) + str(month).zfill(2) + ".nc"
    with nc.Dataset(fn, 'w', format='NETCDF4') as ds:
        ds.createDimension('nlat', nlat)
        ds.createDimension('nlon', nlon)
        ds.createDimension('ntime', ntime)

        lat1D = ds.createVariable('lat1D', 'f', ('nlat',))
        lon1D = ds.createVariable('lon1D', 'f', ('nlon',))
        svm = ds.createVariable('svm', 'f', ('nlon', 'nlat', 'ntime'))
        obs = ds.createVariable('obs', 'f', ('nlon', 'nlat', 'ntime'))
        msm = ds.createVariable('msm', 'f', ('nlon', 'nlat', 'ntime'))

        # Build coordinates from the grid size so their length always equals
        # the dimension (a float-step np.arange can be one element off).
        lat1D[:] = south + _GRID_STEP * np.arange(nlat)
        lon1D[:] = west + _GRID_STEP * np.arange(nlon)
        svm[:] = svm_pred
        obs[:] = obs_target
        msm[:] = msm_target
        obs.units = 'mm/hr'
        obs.long_name = 'Radar-AMeDAS'
        msm.long_name = 'MSM-GPV'


# The feature window around every tested point must lie inside the grid;
# otherwise the window slice selects the wrong number of points and the
# reshape fails in the middle of a long run.
if (xnum1 - 1 - rg < 0 or xnum2 - 1 + rg >= nlon
        or ynum1 - 1 - rg < 0 or ynum2 - 1 + rg >= nlat):
    raise ValueError('tested region plus feature window (rg=%d) does not fit '
                     'in the %d x %d grid' % (rg, nlon, nlat))

# Create the output directory up front so a bad path fails before training.
os.makedirs(save_path, exist_ok=True)

for year in testyears:
    print('ML in year ' + str(year))

    count = allyears.index(year)
    t0, t1 = ind_start[count], ind_end[count]
    # Training time steps: all other loaded years, sub-sampled to every 3rd step.
    train_idx = np.delete(np.arange(total_time), np.arange(t0, t1))[0::3]

    ML_mat = np.full((nlon, nlat, ntime), np.nan)
    target_obs_mat = np.full((nlon, nlat, ntime), np.nan)
    target_msm_mat = np.full((nlon, nlat, ntime), np.nan)

    for ixx in range(xnum1, xnum2 + 1):
        ix = ixx - 1  # 1-based region index -> 0-based array index
        for iyy in range(ynum1, ynum2 + 1):
            iy = iyy - 1

            # (2*rg+1) x (2*rg+1) MSM window centred on (ix, iy); feature
            # matrices have one row per time step, one column per window cell.
            window = msm_mat[ix - rg:ix + rg + 1, iy - rg:iy + rg + 1, :]
            train_x = window[:, :, train_idx].reshape(nums, -1).T
            train_y = obs_mat[ix, iy, train_idx]
            test_x = window[:, :, t0:t1].reshape(nums, -1).T
            test_y = obs_mat[ix, iy, t0:t1]

            # Drop training samples with a missing target or feature.
            ind_train = ~np.isnan(train_y) & ~np.isnan(train_x).any(axis=1)
            if not ind_train.any():
                continue

            # Predict only where the test feature vector is complete; skip
            # the expensive fit when there is nothing to predict.
            ind_test = ~np.isnan(test_x).any(axis=1)
            if not ind_test.any():
                continue

            fit_x, fit_y = _balance_training_set(train_x[ind_train], train_y[ind_train])
            clf = SVR(kernel='rbf', gamma=gm, C=c2, epsilon=ep)
            clf.fit(fit_x, fit_y)
            #joblib.dump(clf, classifier_path+ '/'+str(year)+"_mod1_"+ str(ixx).zfill(3)+ "_" + str(iyy).zfill(3) +".cmp", compress=True)

            y_pred = np.full(ntime, np.nan)
            # Negative precipitation is not physical: clip predictions at zero.
            y_pred[ind_test] = np.maximum(clf.predict(test_x[ind_test]), 0.0)

            ML_mat[ix, iy, :] = y_pred
            target_obs_mat[ix, iy, :] = test_y
            target_msm_mat[ix, iy, :] = msm_mat[ix, iy, t0:t1]

    _write_year(year, ML_mat, target_obs_mat, target_msm_mat)

print("--- %s seconds ---" % (time.time() - start_time))
            
            
