# Author: Qian Liu - George Mason University
# Email: qliu6@gmail.com

# import common libs
import os, sys
import numpy as np
import h5py
from scipy import io as scipyIO
from netCDF4 import Dataset
from sklearn.metrics import mean_squared_error
import scipy.stats as st
# import keras libs
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.optimizers import Adam
from tensorflow.keras import models
from keras.models import model_from_json, load_model
from keras import utils
#utils.to_categorical(...)
#from keras.utils import np_utils
#from keras.wrappers.scikit_learn import KerasClassifier
#from scikeras.wrappers import KerasClassifier
#from keras.wrappers.scikit_learn import KerasRegressor
from keras.callbacks import ModelCheckpoint
# Custom activation function
from keras.layers import Activation
from keras import backend as K
#from keras.utils.generic_utils import get_custom_objects
import tensorflow as tf
#from wrappers.scikeras import KerasClassifier
#from wrappers.scikeras import KerasRegressor
#import scikeras.wrappers.KerasRegressor
#from tensorflow.keras.wrappers.scikit_learn import KerasRegressor
from scikeras.wrappers import KerasClassifier, KerasRegressor
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
# import sklearn libs
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from shutil import copyfile
import glob, os, sys
# import netCDF4
import math
import scipy.interpolate
import scipy.ndimage
import matplotlib as mpl
import scipy.stats as stats
import matplotlib.mlab as mlab
from scipy.io import netcdf
from scipy.stats import lognorm
from scipy.stats import gamma
from scipy.stats import chisquare
# from compiler.ast import flatten
from sklearn import datasets, linear_model
from sklearn.linear_model import LinearRegression
from scipy.stats import norm
# from netCDF4 import Dataset
from scipy.interpolate import griddata
# from mpl_toolkits.basemap import Basemap,addcyclic, shiftgrid,cm
import datetime as dt
from matplotlib.ticker import MaxNLocator
import matplotlib.pyplot as plt
from matplotlib.colors import BoundaryNorm
import pandas as pd
import os
from matplotlib import cm
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from matplotlib import cm
np.random.seed(0)
import matplotlib.ticker
import h5py
from matplotlib.ticker import PercentFormatter
from scipy.stats import gaussian_kde
import seaborn as sns
from numpy import loadtxt
from sklearn.metrics import r2_score
from scipy.stats import sem
from sklearn.metrics import mean_squared_error
from mpl_toolkits.basemap import Basemap
from sklearn.preprocessing import StandardScaler
#from tensorforce import Agent, Environment
from tensorflow.keras import backend as K
from imblearn.over_sampling import SMOTE
from sklearn.utils import class_weight
from tensorflow.keras import layers, models
from sklearn.metrics import f1_score
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import mean_squared_error
from sklearn.metrics import r2_score
from netCDF4 import Dataset
from matplotlib.colors import ListedColormap
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from scipy.stats import pearsonr

# Runtime setup executed at import time.
# Turn on TensorFlow eager execution and report whether it is active.
# NOTE(review): under TensorFlow 2.x eager mode is normally already on, so this
# call is presumably a leftover from TF1-style code — confirm it is still needed.
tf.compat.v1.enable_eager_execution()
print("Eager execution enabled: ", tf.executing_eagerly())
# Select matplotlib's non-interactive 'Agg' backend so plotting does not need a
# display. The name `matplotlib` is bound by `import matplotlib.ticker` above.
matplotlib.use('Agg')

# Activation helpers that clip their input element-wise to the range [0, 1].
def custom_activation(x):
    """Clip ``x`` element-wise to the closed interval [0, 1].

    Behaves like a ReLU with an upper cap of 1: values below 0 become 0 and
    values above 1 become 1.

    Args:
        x: Tensor (or tensor-convertible value) to clip.

    Returns:
        Tensor computed as ``min(max(x, 0), 1)`` with TensorFlow ops.
    """
    floored = tf.maximum(x, 0.0)
    return tf.minimum(floored, 1.0)
class ReLU01(Activation):
    """Keras ``Activation`` layer whose instance ``__name__`` is ``'relu01'``.

    Args:
        activation: Activation callable or identifier, forwarded unchanged to
            ``Activation``.
        **kwargs: Extra keyword arguments forwarded to ``Activation``.
    """

    def __init__(self, activation, **kwargs):
        super().__init__(activation, **kwargs)
        # Give the instance a function-style name. Presumably this lets it be
        # registered/looked up as 'relu01' (compare the commented-out
        # get_custom_objects call in __main__) — TODO confirm.
        self.__name__ = 'relu01'
def relu01(x):
    """Bounded ReLU: return ``x`` clipped element-wise to [0, 1].

    Args:
        x: Tensor (or tensor-convertible value).

    Returns:
        ``min(max(x, 0), 1)`` computed element-wise with TensorFlow ops.
    """
    non_negative = tf.maximum(x, 0.0)
    capped = tf.minimum(non_negative, 1.0)
    return capped
def get_files(dir, ext):
    """Change into ``dir`` and list the entries matching a glob pattern.

    Side effect: the process working directory is changed to ``dir`` and is
    NOT restored. The returned names are whatever ``glob.glob(ext)`` yields
    from there (relative to ``dir`` for a relative pattern).

    Args:
        dir: Directory to search.
        ext: Glob pattern, e.g. ``'*.nc'``.

    Returns:
        Tuple ``(names, count)``: the list of matches in ``glob.glob`` order
        and its length.
    """
    os.chdir(dir)
    matches = list(glob.glob(ext))
    return matches, len(matches)
def is_point_in_box(lat, lon, min_lat, max_lat, min_lon, max_lon):
    """Return whether (lat, lon) lies inside the lat/lon box, bounds inclusive."""
    inside_lat = min_lat <= lat <= max_lat
    # Longitude is only tested when latitude already passed (short-circuit).
    return inside_lat and min_lon <= lon <= max_lon

def check_any_point_in_box(points, min_lat, max_lat, min_lon, max_lon):
    """Return True if at least one (lat, lon) pair in ``points`` lies in the box.

    Scanning stops at the first point found inside. An empty ``points`` gives
    False. Bounds are inclusive (delegates to ``is_point_in_box``).
    """
    return any(
        is_point_in_box(lat, lon, min_lat, max_lat, min_lon, max_lon)
        for lat, lon in points
    )
def extract_h5_by_name(filename, dsname):
    """Read one dataset from an HDF5 file into a NumPy array.

    Args:
        filename: Path to the HDF5 file.
        dsname: Name (path) of the dataset inside the file.

    Returns:
        ``numpy.ndarray`` holding the full contents of the dataset.

    Errors from h5py (e.g. a missing dataset) propagate to the caller; the file
    is closed in either case.
    """
    # Open read-only explicitly (older h5py releases defaulted to append mode).
    # The context manager guarantees the handle is closed even if the lookup
    # below raises; the previous open/close pair leaked it in that case.
    with h5py.File(filename, 'r') as h5_data:
        return np.array(h5_data[dsname][:])

def get_subfiles(dir, ext):
    """Recursively collect file names under ``dir`` that end with ``ext``.

    Only bare file names are returned; the sub-directory part of each path is
    dropped, so the same name found in two sub-directories appears twice.

    Args:
        dir: Root directory to walk (a missing directory yields no matches).
        ext: Suffix to match, e.g. ``'.nc'``.

    Returns:
        Tuple ``(names, count)`` in ``os.walk`` traversal order.
    """
    matches = [
        name
        for _root, _subdirs, names in os.walk(dir)
        for name in names
        if name.endswith(ext)
    ]
    return matches, len(matches)

def get_startfiles(dir, ext):
    """Recursively collect file names under ``dir`` that start with ``ext``.

    Despite its name, ``ext`` is a filename *prefix* here. Only bare file
    names are returned; sub-directory paths are dropped.

    Args:
        dir: Root directory to walk (a missing directory yields no matches).
        ext: Prefix to match, e.g. ``'OMPS'``.

    Returns:
        Tuple ``(names, count)`` in ``os.walk`` traversal order.
    """
    found = []
    for _root, _subdirs, names in os.walk(dir):
        found.extend(name for name in names if name.startswith(ext))
    return found, len(found)

def extract_h4_by_name(filename, dsname):
    """Read one variable from a file opened with ``netCDF4.Dataset`` as a NumPy array.

    Args:
        filename: Path to the file (opened read-only, the ``Dataset`` default).
        dsname: Name of the variable to read.

    Returns:
        ``numpy.ndarray`` with the full contents of the variable.

    Errors from netCDF4 (e.g. an unknown variable name) propagate to the
    caller; the file is closed in either case.
    """
    # The context manager closes the file even when the lookup raises; the
    # previous open/close pair leaked the handle in that case.
    with Dataset(filename) as h4_data:
        # NOTE(review): netCDF4 may hand back a masked array; np.array() keeps
        # the raw values and drops the mask, so fill values (not NaN) reach
        # the caller unless the file stores NaN — confirm for these inputs.
        return np.array(h4_data[dsname][:])

# # # #  # # # # # # # # # # # # # # # # # # # #  # # # # #
# # # #  ###### using Keras  ######  # # # #
# # # #  # # # # # # # # # # # # # # # # # # # #  # # # # #
# main functions
if __name__ == '__main__':
    # Apply a pre-trained Keras DNN to the leading PCA components stored with
    # OMPS pixels for one day. Print regression metrics against OMPS AOD, then
    # write the selected pixels, predictions and the VIIRS fields to netCDF.

    # ---------------- configuration ----------------
    # Lat/lon box (degrees) applied to the OMPS pixel lat/lon.
    min_lat = 15
    max_lat = 67
    min_lon = -153
    max_lon = -60
    npc = 14  # number of leading PCA components fed to the model
    target = 'aod'
    sate = 'j02'
    # Per-satellite settings. Only 'txtdir' (directory holding the input and
    # output files) is used below; the other entries are kept for reference.
    sate_config = {
        'j02': {'sate_cap': 'N21', 'n_viirs': 400,
                'txtdir': '/data/data627/qliu/omps_aod_result/pca_viirs_to_omps/PCtxt/'},
        'j01': {'sate_cap': 'N20', 'n_viirs': 4000,
                'txtdir': '/data/data627/qliu/omps_aod_result/pca_viirs_to_omps/PCtxt/'},
        'npp': {'sate_cap': 'SNPP', 'n_viirs': 12000,
                'txtdir': '/data/data627/qliu/omps_aod_result/pca_viirs_to_omps/PCtxt/'},
        'hrj01': {'sate_cap': 'HRN20', 'n_viirs': 1000,
                  'txtdir': '/data/data627/qliu/omps_aod_result/pca_viirs_to_omps/rhn20/PCtxt/'},
    }
    if sate not in sate_config:
        # Fail fast with a clear message instead of a NameError further down.
        sys.exit(f'unknown satellite {sate!r}; expected one of {sorted(sate_config)}')
    viirs_cpr_txtdir = sate_config[sate]['txtdir']
    clear_thr = 0.5  # minimum OMPS clear fraction for a pixel to be kept
    thr2 = 0.2       # threshold value embedded in the model/input/output file names

    model_dir = '/data/data627/qliu/omps_aod_result/pca_viirs_to_omps/dnn_model/'
    model_name = (f'20231115_20240925_119_pca_norm_aod_{thr2}_clear{clear_thr}'
                  f'_j02.nc_dnn_{target}_model_first{npc}pc_j02_{thr2}.h5')

    # Date of the input file; start and end day are equal for a one-day run.
    year, month = 2025, 6
    day_start = day_end = 9
    file_nameidx = f'{year}{month:02d}{day_start:02d}_{year}{month:02d}{day_end:02d}'
    file_name = (f'{file_nameidx}_119_pca_norm_{target}_{thr2}'
                 f'_clear{clear_thr}_{sate}.nc')
    print(file_name)
    in_path = viirs_cpr_txtdir + file_name

    # ---------------- read inputs ----------------
    omps_lat = extract_h4_by_name(in_path, 'omps_lat')
    omps_lon = extract_h4_by_name(in_path, 'omps_lon')
    omps_lat_fov = extract_h4_by_name(in_path, 'omps_lat_fov')
    omps_lon_fov = extract_h4_by_name(in_path, 'omps_lon_fov')
    omps_clearfrac = extract_h4_by_name(in_path, 'omps_clearfrac')
    omps_aod = extract_h4_by_name(in_path, 'omps_aod')
    X_pca_norm = extract_h4_by_name(in_path, 'X_pca_norm')
    # The VIIRS fields are written back unfiltered, flattened to 1-D.
    viirs_lat = extract_h4_by_name(in_path, 'all_viirs_lat').flatten()
    viirs_lon = extract_h4_by_name(in_path, 'all_viirs_lon').flatten()
    viirs_M_rfl3 = extract_h4_by_name(in_path, 'all_viirs_rfl3').flatten()
    viirs_M_rfl4 = extract_h4_by_name(in_path, 'all_viirs_rfl4').flatten()
    viirs_M_rfl5 = extract_h4_by_name(in_path, 'all_viirs_rfl5').flatten()
    print(viirs_lat.shape, viirs_M_rfl3.shape)

    # ---------------- select OMPS pixels ----------------
    # A single boolean mask selects pixels whose clear fraction is in
    # [clear_thr, 1], whose clear fraction and AOD are not NaN, and whose
    # center lies inside the box. Boolean indexing keeps 2-D arrays 2-D.
    # The previous np.where(...) + .squeeze() collapsed them when only one
    # pixel survived (or npc == 1), breaking the later [mask, :] indexing
    # and model.predict.
    # NOTE(review): assumes the per-pixel omps_* arrays are 1-D with equal
    # length, and that omps_*_fov / X_pca_norm are (n_pixels, ...) — confirm.
    keep = ((omps_clearfrac >= clear_thr) & (omps_clearfrac <= 1)
            & ~np.isnan(omps_clearfrac) & ~np.isnan(omps_aod)
            & (omps_lat >= min_lat) & (omps_lat <= max_lat)
            & (omps_lon >= min_lon) & (omps_lon <= max_lon))
    omps_lat = omps_lat[keep]
    omps_lon = omps_lon[keep]
    omps_lat_fov = omps_lat_fov[keep, :]
    omps_lon_fov = omps_lon_fov[keep, :]
    omps_aod = omps_aod[keep]
    X_pca_norm = X_pca_norm[keep, :npc]
    print(omps_lat_fov.shape, X_pca_norm.shape)
    if omps_aod.size == 0:
        # Nothing to predict; the predict and metric calls would fail on empty input.
        sys.exit(f'no valid OMPS pixels for {file_nameidx}; nothing to predict')

    # ---------------- predict and score ----------------
    x_input = np.asarray(X_pca_norm, dtype=float)
    model = load_model(model_dir + model_name)
    prediction = model.predict(x_input)
    # A single-output network returns shape (n, 1). Flatten it to (n,) so it
    # matches omps_aod and the 1-D 'prediction' netCDF variable.
    if prediction.ndim == 2 and prediction.shape[1] == 1:
        prediction = prediction[:, 0]
    mae = mean_absolute_error(omps_aod, prediction)
    mse = mean_squared_error(omps_aod, prediction)
    # `squared=False` was deprecated in scikit-learn 1.4 and removed in 1.6.
    # For a single target, RMSE is simply sqrt(MSE).
    rmse = np.sqrt(mse)
    r2 = r2_score(omps_aod, prediction)
    print(mae, mse, rmse, r2)

    # ---------------- write output ----------------
    output_stat = viirs_cpr_txtdir + f'{file_nameidx}_{npc}_prediction_{target}_{thr2}_{sate}.nc'
    # (variable name, data, dimensions); long_name is set equal to the name.
    variables = [
        ('omps_lat', omps_lat, ('x',)),
        ('omps_lon', omps_lon, ('x',)),
        ('omps_lat_fov', omps_lat_fov, ('x', 'f')),
        ('omps_lon_fov', omps_lon_fov, ('x', 'f')),
        ('omps_aod', omps_aod, ('x',)),
        ('prediction', prediction, ('x',)),
        ('all_viirs_rfl3', viirs_M_rfl3, ('vx',)),
        ('all_viirs_rfl4', viirs_M_rfl4, ('vx',)),
        ('all_viirs_rfl5', viirs_M_rfl5, ('vx',)),
        ('all_viirs_lat', viirs_lat, ('vx',)),
        ('all_viirs_lon', viirs_lon, ('vx',)),
    ]
    # The context manager closes (and flushes) the file even if a write fails.
    with Dataset(output_stat, mode='w', format='NETCDF4') as ncfile:
        ncfile.createDimension('x', prediction.shape[0])       # selected OMPS pixels
        ncfile.createDimension('f', omps_lat_fov.shape[1])     # values per pixel in *_fov
        ncfile.createDimension('vx', viirs_M_rfl3.shape[0])    # flattened VIIRS samples
        for var_name, values, dims in variables:
            var = ncfile.createVariable(var_name, np.float32, dims)
            var.long_name = var_name
            var[:] = values
    print('finish ' + file_nameidx)