# Author: Qian Liu - George Mason University
# Email: qliu6@gmail.com

# import common libs
import os, sys
import numpy as np
import h5py
from scipy import io as scipyIO
from netCDF4 import Dataset
from sklearn.metrics import mean_squared_error
import scipy.stats as st
# import keras libs
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.optimizers import Adam
from tensorflow.keras import models
from keras.models import model_from_json, load_model
from keras import utils
#utils.to_categorical(...)
#from keras.utils import np_utils
#from keras.wrappers.scikit_learn import KerasClassifier
#from scikeras.wrappers import KerasClassifier
#from keras.wrappers.scikit_learn import KerasRegressor
from keras.callbacks import ModelCheckpoint
# Custom activation function
from keras.layers import Activation
from keras import backend as K
#from keras.utils.generic_utils import get_custom_objects
import tensorflow as tf
#from wrappers.scikeras import KerasClassifier
#from wrappers.scikeras import KerasRegressor
#import scikeras.wrappers.KerasRegressor
#from tensorflow.keras.wrappers.scikit_learn import KerasRegressor
from scikeras.wrappers import KerasClassifier, KerasRegressor
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
# import sklearn libs
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from shutil import copyfile
import glob, os, sys
# import netCDF4
import math
import scipy.interpolate
import scipy.ndimage
import matplotlib as mpl
import scipy.stats as stats
import matplotlib.mlab as mlab
from scipy.io import netcdf
from scipy.stats import lognorm
from scipy.stats import gamma
from scipy.stats import chisquare
# from compiler.ast import flatten
from sklearn import datasets, linear_model
from sklearn.linear_model import LinearRegression
from scipy.stats import norm
# from netCDF4 import Dataset
from scipy.interpolate import griddata
# from mpl_toolkits.basemap import Basemap,addcyclic, shiftgrid,cm
import datetime as dt
from matplotlib.ticker import MaxNLocator
import matplotlib.pyplot as plt
from matplotlib.colors import BoundaryNorm
import pandas as pd
import os
from matplotlib import cm
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
from matplotlib import cm
np.random.seed(0)
import matplotlib.ticker
import h5py
from matplotlib.ticker import PercentFormatter
from scipy.stats import gaussian_kde
import seaborn as sns
from numpy import loadtxt
from sklearn.metrics import r2_score
from scipy.stats import sem
from sklearn.metrics import mean_squared_error
from mpl_toolkits.basemap import Basemap
from sklearn.preprocessing import StandardScaler
#from tensorforce import Agent, Environment
from tensorflow.keras import backend as K
from imblearn.over_sampling import SMOTE
from sklearn.utils import class_weight
from tensorflow.keras import layers, models
from netCDF4 import Dataset
from tensorflow.keras import layers
import os
import keras_tuner as kt
from tensorflow import keras
import shutil

# Switch TensorFlow to eager execution through the TF1 compatibility API and
# report the resulting state so it shows up in the run log.
tf.compat.v1.enable_eager_execution()
print("Eager execution enabled: ", tf.executing_eagerly())
# Select the non-interactive Agg backend for matplotlib.
# NOTE: the bare name `matplotlib` is bound by `import matplotlib.ticker`
# above; `import matplotlib as mpl` only binds `mpl`.
matplotlib.use('Agg')

# [0-1]
def custom_activation(x):
    """Clamp ``x`` to the closed interval [0, 1].

    Values below 0 become 0 and values above 1 become 1; everything in
    between passes through unchanged.
    """
    floored = tf.maximum(x, 0.0)
    return tf.minimum(floored, 1.0)
class ReLU01(Activation):
    """``Activation`` layer subclass that tags itself with the name 'relu01'."""

    def __init__(self, activation, **kwargs):
        """Forward ``activation`` and any keyword arguments to ``Activation``."""
        super().__init__(activation, **kwargs)
        # Expose a function-like __name__ on the layer instance.
        self.__name__ = 'relu01'
def relu01(x):
    """Return ``x`` limited to the range [0, 1] (a ReLU capped at 1)."""
    lower_bound, upper_bound = 0.0, 1.0
    return tf.minimum(tf.maximum(x, lower_bound), upper_bound)
def get_files(dir, ext):
    """List the entries of ``dir`` that match the glob pattern ``ext``.

    Side effect: the process working directory is changed to ``dir`` and is
    NOT restored afterwards. The returned names are therefore relative to
    ``dir``.

    Returns:
        A tuple ``(names, count)`` where ``names`` is the list of matches
        and ``count`` is its length.
    """
    os.chdir(dir)
    matched = list(glob.glob(ext))
    return matched, len(matched)
def is_point_in_box(lat, lon, min_lat, max_lat, min_lon, max_lon):
    """Tell whether (lat, lon) lies inside the box, edges included.

    The latitude range is checked first; the longitude range is only
    evaluated when the latitude check passes.
    """
    lat_inside = min_lat <= lat <= max_lat
    if not lat_inside:
        return lat_inside
    return min_lon <= lon <= max_lon

def check_any_point_in_box(points, min_lat, max_lat, min_lon, max_lon):
    """Return True if at least one (lat, lon) pair in ``points`` is in the box.

    Scanning stops at the first hit. An empty ``points`` gives False.
    """
    return any(
        is_point_in_box(pt_lat, pt_lon, min_lat, max_lat, min_lon, max_lon)
        for pt_lat, pt_lon in points
    )
def extract_h5_by_name(filename, dsname):
    """Read one dataset from an HDF5 file and return it as a NumPy array.

    Args:
        filename: Path of the HDF5 file to read.
        dsname: Name (key) of the dataset inside the file.

    Returns:
        A ``numpy.ndarray`` holding the full contents of the dataset.

    Any exception raised while opening the file or looking up ``dsname``
    propagates to the caller; the file handle is closed in every case.
    """
    # Open explicitly read-only, and use a context manager so the handle is
    # released even when the dataset lookup fails.
    with h5py.File(filename, "r") as h5_data:
        return np.array(h5_data[dsname][:])
def get_subfiles(dir, ext):
    """Collect file names under ``dir`` (recursively) that end with ``ext``.

    Only bare file names are returned, without their directory part.

    Returns:
        A tuple ``(names, count)`` where ``count == len(names)``.
    """
    matched = [
        fname
        for _root, _dirs, fnames in os.walk(dir)
        for fname in fnames
        if fname.endswith(ext)
    ]
    return matched, len(matched)

def get_startfiles(dir, ext):
    """Collect file names under ``dir`` (recursively) that start with ``ext``.

    Despite its name, ``ext`` is used as a file-name prefix here. Only bare
    file names are returned, without their directory part.

    Returns:
        A tuple ``(names, count)`` where ``count == len(names)``.
    """
    matched = []
    for _root, _dirs, fnames in os.walk(dir):
        matched.extend(fname for fname in fnames if fname.startswith(ext))
    return matched, len(matched)

def extract_h4_by_name(filename, dsname):
    """Read one variable from a file opened with ``netCDF4.Dataset``.

    Args:
        filename: Path of the file to read.
        dsname: Name of the variable inside the file.

    Returns:
        A ``numpy.ndarray`` holding the full contents of the variable.

    Any exception raised while opening the file or looking up ``dsname``
    propagates to the caller; the file handle is closed in every case.
    """
    # The context manager closes the dataset even when the lookup raises.
    with Dataset(filename) as h4_data:
        return np.array(h4_data[dsname][:])

def extract_nc3_by_name(filename, dsname):
    """Read one variable from a NetCDF-3 file and return it as a NumPy array.

    Args:
        filename: Path of the NetCDF-3 file to read.
        dsname: Name of the variable inside the file.

    Returns:
        A ``numpy.ndarray`` copy of the variable's data, valid after the
        file has been closed.

    Raises:
        KeyError: If ``dsname`` is not a variable of the file.
    """
    # Use the public scipy.io entry point; the `scipy.io.netcdf` namespace
    # is deprecated.
    from scipy.io import netcdf_file

    # The context manager closes the file even when the lookup raises.
    # np.array() copies the data before the file is closed.
    with netcdf_file(filename, "r") as nc_data:
        return np.array(nc_data.variables[dsname][:])

def extract_h5_by_name(filename, dsname):
    """Read one dataset from an HDF5 file and return it as a NumPy array.

    NOTE: this is a second definition of ``extract_h5_by_name`` in this
    module; being the later one, it is the definition in effect at run time.

    Args:
        filename: Path of the HDF5 file to read.
        dsname: Name (key) of the dataset inside the file.

    Returns:
        A ``numpy.ndarray`` holding the full contents of the dataset.

    Any exception raised while opening the file or looking up ``dsname``
    propagates to the caller; the file handle is closed in every case.
    """
    # Open explicitly read-only, and use a context manager so the handle is
    # released even when the dataset lookup fails.
    with h5py.File(filename, "r") as h5_data:
        return np.array(h5_data[dsname][:])
# # # #  # # # # # # # # # # # # # # # # # # # #  # # # # #
# # # #  ###### using Keras  ######  # # # # 
# # # #  # # # # # # # # # # # # # # # # # # # #  # # # # # 
# main functions
if __name__ == '__main__':
    # Driver: load the PCA-normalised training file, run a Keras Tuner random
    # search over a three-hidden-layer regression DNN, evaluate the best model
    # on a held-out split, then save the model and a text summary of the run.

    # ---- configuration -------------------------------------------------
    sate = 'j02'
    all_stat = []  # not used below
    npc = 8  # number of leading columns of X_pca_norm used as features (original note: 121)
    showhis = 0  # not used below
    clear_thr = 0.5  # minimum omps_clearfrac for a sample to be kept
    thr2 = 0.3  # only used in file names here
    print (str(thr2))
    target = 'aod'
    viirs_cpr_txtdir = '/data/data627/qliu/omps_aod_result/pca_viirs_to_omps/PCtxt/pc_train/'
    out_dir = '/data/data627/qliu/omps_aod_result/pca_viirs_to_omps/dnn_model/'
    # Same name as before, but the satellite id now comes from `sate`
    # instead of being hard-coded a second time.
    file_nameidx_cldsat = (
        f'20231115_20240925_119_pca_norm_{target}_{thr2}_clear{clear_thr}_{sate}.nc'
    )
    file_name = file_nameidx_cldsat
    data_path = viirs_cpr_txtdir + file_name

    # ---- load inputs ---------------------------------------------------
    # Only omps_clearfrac, omps_aod and X_pca_norm are used below; the other
    # variables are still read so the I/O behaviour of the script is unchanged.
    omps_lat = extract_h4_by_name(data_path, 'omps_lat')
    omps_lon = extract_h4_by_name(data_path, 'omps_lon')
    omps_sunz = extract_h4_by_name(data_path, 'omps_sunz')
    omps_satez = extract_h4_by_name(data_path, 'omps_satez')
    omps_clearfrac = extract_h4_by_name(data_path, 'omps_clearfrac')
    omps_aod = extract_h4_by_name(data_path, 'omps_aod')
    omps_saai = extract_h4_by_name(data_path, 'omps_saai')
    X_pca_norm = extract_h4_by_name(data_path, 'X_pca_norm')

    # Keep samples whose AOD is not NaN and whose clear fraction reaches
    # the threshold.
    idx_valid = np.where(~np.isnan(omps_aod) & (omps_clearfrac >= clear_thr))
    y = omps_aod[idx_valid]
    print(y.shape)
    # NOTE(review): indexing with the np.where tuple appears to add a leading
    # length-1 axis that squeeze() removes; this presumably assumes omps_aod
    # is 1-D and more than one sample survives the filter -- TODO confirm.
    x = np.array(X_pca_norm[idx_valid, :npc]).squeeze()
    x_train_ALL = np.array(x).astype(float)
    y_train_ALL = np.array(y).astype(float)
    print(x_train_ALL.shape,y_train_ALL.shape)
    # Per-feature statistics; only the shape of the std array is printed.
    x_train_std = np.std(x_train_ALL, axis=0).reshape((1, x.shape[-1]))
    x_train_mean = np.nanmean(x_train_ALL, axis=0).reshape((1, x.shape[-1]))
    print(x_train_std.shape)

    # 80/20 train/test split with a fixed seed.
    x_train, x_test, y_train, y_test = train_test_split(
        x_train_ALL, y_train_ALL, test_size=0.2, random_state=42)
    input_dim = x_train.shape[1]

    # ---- hyperparameter search ----------------------------------------
    # Build the tuner location once and reuse it for both the cleanup and
    # the tuner itself so the two can never drift apart.
    tuner_root = 'tuner_dir_' + str(npc)
    tuner_project = 'dnn_regression_tuning'
    tuner_dir = os.path.join(tuner_root, tuner_project)
    # Remove results of any earlier search stored under the same project
    # directory (presumably so the search starts from scratch).
    shutil.rmtree(tuner_dir, ignore_errors=True)

    def build_regression_model(hp):
        """Build and compile a 3-hidden-layer regression DNN from ``hp``.

        Tuned hyperparameters: units1/2/3 (layer widths), activation1/2/3
        ('relu' or 'tanh') and optimizer ('adam' or 'rmsprop').
        """
        model = keras.Sequential()
        model.add(layers.Dense(
            hp.Int('units1', min_value=64, max_value=256, step=64),
            activation=hp.Choice('activation1', ['relu', 'tanh']),
            input_shape=(input_dim,)
        ))
        model.add(layers.Dense(
            hp.Int('units2', min_value=32, max_value=128, step=32),
            activation=hp.Choice('activation2', ['relu', 'tanh'])
        ))
        model.add(layers.Dense(
            hp.Int('units3', min_value=16, max_value=64, step=16),
            activation=hp.Choice('activation3', ['relu', 'tanh'])
        ))
        model.add(layers.Dense(1))  # No activation for regression output

        model.compile(
            optimizer=hp.Choice('optimizer', ['adam', 'rmsprop']),
            loss='mean_squared_error',
            metrics=['mae']
        )
        return model

    tuner = kt.RandomSearch(
        build_regression_model,
        objective='val_mae',  # minimize validation MAE
        max_trials=10,  # Try 10 different models
        executions_per_trial=1,
        directory=tuner_root,
        project_name=tuner_project
    )

    # Stop a trial early when the validation loss has not improved for
    # 5 epochs.
    early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)

    tuner.search(
        x_train, y_train,
        epochs=20,
        batch_size=32,
        validation_split=0.1,  # 10% of the training data is held out for validation
        callbacks=[early_stop],
        verbose=1
    )

    best_model = tuner.get_best_models(num_models=1)[0]

    # ---- evaluate and save --------------------------------------------
    # The model is compiled with metrics=['mae'], so evaluate() yields
    # the loss (MSE) followed by the MAE.
    test_loss, test_mae = best_model.evaluate(x_test, y_test)
    print(f"Test Loss: {test_loss:.4f}, Test MAE: {test_mae:.4f}")

    model_name = (
        f'{file_nameidx_cldsat}_dnn_{target}_model_first{npc}pc_{sate}_{thr2}.h5'
    )
    best_model.save(os.path.join(out_dir, model_name))

    best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]
    print(f"""
    Best Hyperparameters:
    units1: {best_hps.get('units1')}
    units2: {best_hps.get('units2')}
    units3: {best_hps.get('units3')}
    activation1: {best_hps.get('activation1')}
    activation2: {best_hps.get('activation2')}
    activation3: {best_hps.get('activation3')}
    optimizer: {best_hps.get('optimizer')}
    """)

    # ---- write the text summary ---------------------------------------
    # Fix: the previous code formatted undefined names (`loss`, `accuracy`)
    # and raised NameError after the whole search had finished. The values
    # to report are the evaluate() results above, and the second one is an
    # MAE, not an accuracy.
    acc_file = os.path.join(out_dir, model_name + ".txt")
    hp_names = ('units1', 'units2', 'units3',
                'activation1', 'activation2', 'activation3',
                'optimizer')
    with open(acc_file, "w") as f:
        f.write(f"Test loss: {test_loss:.6f}\n")
        f.write(f"Test MAE: {test_mae:.6f}\n")
        f.write("The hyperparameter search is complete.\n")
        for hp_name in hp_names:
            f.write(f"Best {hp_name}: {best_hps.get(hp_name)}\n")

    '''
    def build_dnn_model(input_shape):
        model = models.Sequential([
            layers.Dense(128, activation='relu', input_shape=(input_shape,)),  # First hidden layer with 128 neurons
            layers.Dense(64, activation='relu'),  # Second hidden layer with 64 neurons
            layers.Dense(32, activation='relu'),  # Third hidden layer with 32 neurons
            layers.Dense(1)  # Output layer for regression (single continuous value)
        ])
        return model
    model = build_dnn_model(input_dim)
    # Compile the model with a loss function and optimizer for regression
    model.compile(optimizer='adam', loss='mean_squared_error', metrics=['mae'])  # MAE: Mean Absolute Error
    # Print the model architecture
    model.summary()
    history = model.fit(x_train, y_train, epochs=20, batch_size=32, validation_split=0.1)
    val_loss_list = history.history['val_loss']
    # Calculate the mean validation loss
    mean_val_loss = np.mean(val_loss_list)
    print(f"Mean validation loss: {mean_val_loss}")
    plt.plot(history.history['loss'], label='train loss')
    plt.plot(history.history['val_loss'], label='validation loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(out_dir + 'DNN_ACC_first_' + str(npc) + '_' + file_nameidx_cldsat +
                '.png', bbox_inches='tight', dpi=200)
    test_loss, test_mae = model.evaluate(x_test, y_test)
    print(f"Test Loss: {test_loss}, Test MAE: {test_mae}")
    model_name = file_nameidx_cldsat + '_dnn_' + target + '_model_first' + str(npc) + 'pc_' + sate + '_' + str(
        thr2) + '.h5'
    model.save(out_dir + model_name)
    '''
    '''
    def build_regression_model(hp):
        model = keras.Sequential()
        model.add(layers.Dense(
            hp.Int('units1', min_value=64, max_value=256, step=64),  # More neurons to try
            activation=hp.Choice('activation1', ['relu', 'tanh']),
            input_shape=(input_dim,)
        ))
        model.add(layers.Dense(
            hp.Int('units2', min_value=32, max_value=128, step=32),
            activation=hp.Choice('activation2', ['relu', 'tanh'])
        ))
        model.add(layers.Dense(
            hp.Int('units3', min_value=16, max_value=64, step=16),
            activation=hp.Choice('activation3', ['relu', 'tanh'])
        ))
        model.add(layers.Dense(1))  # No activation for regression output

        model.compile(
            optimizer=hp.Choice('optimizer', ['adam', 'rmsprop']),
            loss='mean_squared_error',
            metrics=['mae']
        )
        return model
    # 3. Set up tuner
    tuner = kt.RandomSearch(
        build_regression_model,
        objective='val_mae',  # minimize validation MAE
        max_trials=10,  # Try 10 different models
        executions_per_trial=1,
        directory='tuner_dir',
        project_name='dnn_regression_tuning'
    )

    # 4. Search for the best model
    early_stop = keras.callbacks.EarlyStopping(monitor='val_loss', patience=5)

    tuner.search(
        x_train, y_train,
        epochs=20,
        batch_size=32,
        validation_split=0.1,
        callbacks=[early_stop],
        verbose=1
    )

    # 5. Get the best model
    best_model = tuner.get_best_models(num_models=1)[0]

    # 6. (Optional) Fine-tune further if you want
    # history = best_model.fit(x_train, y_train, epochs=20, batch_size=32, validation_split=0.1)

    # 7. Plot loss curve if you fine-tune
    # plt.plot(history.history['loss'], label='Train Loss')
    # plt.plot(history.history['val_loss'], label='Validation Loss')
    # plt.xlabel('Epochs')
    # plt.ylabel('Loss')
    # plt.legend()
    # plt.savefig(out_dir + f'DNN_ACC_first_{npc}_' + file_nameidx_cldsat + '.png', bbox_inches='tight', dpi=200)

    # 8. Evaluate
    test_loss, test_mae = best_model.evaluate(x_test, y_test)
    print(f"Test Loss: {test_loss:.4f}, Test MAE: {test_mae:.4f}")

    # 9. Save model
    model_name = file_nameidx_cldsat + '_dnn_' + target + '_model_first' + str(npc) + 'pc_' + sate + '_' + str(
        thr2) + '.h5'
    best_model.save(os.path.join(out_dir, model_name))

    # 10. Print best hyperparameters
    best_hps = tuner.get_best_hyperparameters(num_trials=1)[0]
    print(f"""
    Best Hyperparameters:
    units1: {best_hps.get('units1')}
    units2: {best_hps.get('units2')}
    units3: {best_hps.get('units3')}
    activation1: {best_hps.get('activation1')}
    activation2: {best_hps.get('activation2')}
    activation3: {best_hps.get('activation3')}
    optimizer: {best_hps.get('optimizer')}
    """)

    '''
    '''
    def build_dnn_model(input_shape):
        model = models.Sequential([
            layers.Dense(128, activation='relu', input_shape=(input_shape,)),  # First hidden layer with 128 neurons
            layers.Dense(64, activation='relu'),  # Second hidden layer with 64 neurons
            layers.Dense(32, activation='relu'),  # Third hidden layer with 32 neurons
            layers.Dense(1)  # Output layer for regression (single continuous value)
        ])
        return model
    model = build_dnn_model(input_dim)
    # Compile the model with a loss function and optimizer for regression
    model.compile(optimizer='adam', loss='mean_squared_error', metrics=['mae'])  # MAE: Mean Absolute Error
    # Print the model architecture
    model.summary()
    history = model.fit(x_train, y_train, epochs=20, batch_size=32, validation_split=0.1)
    val_loss_list = history.history['val_loss']
    # Calculate the mean validation loss
    mean_val_loss = np.mean(val_loss_list)
    print(f"Mean validation loss: {mean_val_loss}")
    plt.plot(history.history['loss'], label='train loss')
    plt.plot(history.history['val_loss'], label='validation loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.savefig(out_dir + 'DNN_ACC_first_'+str(npc)+'_' + file_nameidx_cldsat +
                '.png', bbox_inches='tight', dpi=200)
    test_loss, test_mae = model.evaluate(x_test, y_test)
    print(f"Test Loss: {test_loss}, Test MAE: {test_mae}")
    model_name = file_nameidx_cldsat+'_dnn_'+target+'_model_first'+str(npc)+'pc_'+sate+'_'+str(thr2)+'.h5'
    model.save(out_dir  + model_name)
    '''
