import glob, os, sys
# import netCDF4
import math
import scipy.interpolate
import scipy.ndimage
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import scipy.stats as stats
import matplotlib.mlab as mlab
from scipy.io import netcdf
from scipy.stats import lognorm
from scipy.stats import gamma
from scipy.stats import chisquare
# from compiler.ast import flatten
from sklearn import datasets, linear_model
from sklearn.linear_model import LinearRegression
from scipy.stats import norm
# from netCDF4 import Dataset
from scipy.interpolate import griddata
# from mpl_toolkits.basemap import Basemap,addcyclic, shiftgrid,cm
import datetime as dt
from matplotlib.ticker import MaxNLocator
import matplotlib.pyplot as plt
from matplotlib.colors import BoundaryNorm
import pandas as pd
import os
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import ListedColormap, LinearSegmentedColormap
import numpy as np
from matplotlib import cm
np.random.seed(0)
import matplotlib.pyplot as plt
import matplotlib.ticker
import h5py
from matplotlib.ticker import PercentFormatter
from scipy.stats import gaussian_kde
import seaborn as sns
from numpy import loadtxt
from sklearn.metrics import r2_score
from scipy.stats import sem
from sklearn.metrics import mean_squared_error
from mpl_toolkits.basemap import Basemap
from sklearn.decomposition import PCA, KernelPCA
from sklearn.preprocessing import StandardScaler
import random
import datetime
from netCDF4 import Dataset
from haversine import haversine, Unit

# Select the non-interactive Agg backend (figures are rendered off-screen).
# The bare name `matplotlib` is bound by the earlier `import matplotlib.ticker`.
matplotlib.use('Agg')
def extract_nc3_by_name(filename, dsname):
    """Read one variable from a NetCDF-3 file into memory.

    Args:
        filename: Path of the NetCDF-3 file.
        dsname: Name of the variable to read.

    Returns:
        numpy.ndarray holding a copy of the full variable.

    Raises:
        KeyError: If ``dsname`` is not a variable of the file.
    """
    # The context manager closes the file even when the variable lookup fails.
    with netcdf.netcdf_file(filename, "r") as nc_data:
        # np.array() copies the values, so the result stays valid after the close.
        ds = np.array(nc_data.variables[dsname][:])
    return ds

def extract_h5_by_name(filename, dsname):
    """Read one dataset from an HDF5 file into memory.

    Args:
        filename: Path of the HDF5 file.
        dsname: Name (path inside the file) of the dataset to read.

    Returns:
        numpy.ndarray holding a copy of the full dataset.

    Raises:
        KeyError: If ``dsname`` does not exist in the file.
    """
    # Open explicitly read-only: h5py releases before 3.0 picked a mode on
    # their own when none was given.  The context manager closes the file
    # even if the dataset lookup raises.
    with h5py.File(filename, "r") as h5_data:
        return np.array(h5_data[dsname][:])

def extract_h4_by_name(filename, dsname):
    """Read one variable from a file opened with ``netCDF4.Dataset``.

    Despite the "h4" in the name, the file is opened through the netCDF4
    library.

    Args:
        filename: Path of the file.
        dsname: Name of the variable to read.

    Returns:
        numpy.ndarray with the full variable.  NOTE(review): if netCDF4 hands
        back a masked array, np.array() keeps the raw values (fill values
        included) and drops the mask, so callers must screen invalid values
        themselves.

    Raises:
        IndexError: If ``dsname`` is not found in the file.
    """
    # The context manager closes the file even if the variable lookup raises.
    with Dataset(filename, "r") as h4_data:
        return np.array(h4_data[dsname][:])

def get_files(dir, ext):
    """Change into ``dir`` and list the entries matching the glob pattern ``ext``.

    Side effect: the process working directory is switched to ``dir`` and is
    not restored afterwards.

    Args:
        dir: Directory to search (non-recursive).
        ext: Glob pattern, e.g. ``'*.nc'``.

    Returns:
        ``(names, count)`` where ``names`` are paths relative to ``dir``.
    """
    os.chdir(dir)
    matches = list(glob.glob(ext))
    return matches, len(matches)
def get_subfiles(dir, ext):
    """Recursively list files under ``dir`` whose name ends with ``ext``.

    Only the bare file names are returned (the sub-directory each file was
    found in is not part of the result).

    Args:
        dir: Root directory of the walk.
        ext: Required file-name suffix, e.g. ``'.nc'``.

    Returns:
        ``(names, count)``.
    """
    found = [
        name
        for _root, _subdirs, names in os.walk(dir)
        for name in names
        if name.endswith(ext)
    ]
    return found, len(found)

def get_startfiles(dir, ext):
    """Recursively list files under ``dir`` whose name starts with ``ext``.

    Only the bare file names are returned (the sub-directory each file was
    found in is not part of the result).

    Args:
        dir: Root directory of the walk.
        ext: Required file-name prefix.

    Returns:
        ``(names, count)``.
    """
    found = [
        name
        for _root, _subdirs, names in os.walk(dir)
        for name in names
        if name.startswith(ext)
    ]
    return found, len(found)

# Colormap for density-style plots: exactly 0 maps to white, and the colour
# ramp (viridis-like, per the name) starts immediately above 0 thanks to the
# second stop at 1e-20.  256 colour levels are generated.
white_viridis = LinearSegmentedColormap.from_list('white_viridis', [
    (0, '#ffffff'),
    (1e-20, '#440053'),
    (0.2, '#404388'),
    (0.4, '#2a788e'),
    (0.6, '#21a784'),
    (0.8, '#78d151'),
    (1, '#fde624'),
], N=256)

def plot_examples(cms):
    """Show the same random 30x30 field side by side with two colormaps.

    Side effect: reseeds NumPy's global random generator with 19680801.

    Args:
        cms: Iterable of colormaps; paired with the two subplot axes.
    """
    np.random.seed(19680801)
    sample = np.random.randn(30, 30)

    figure, axes = plt.subplots(1, 2, figsize=(6, 3), constrained_layout=True)
    for axis, colormap in zip(axes, cms):
        mesh = axis.pcolormesh(sample, cmap=colormap, rasterized=True, vmin=-4, vmax=4)
        figure.colorbar(mesh, ax=axis)
    plt.show()

def get_density(x: np.ndarray, y: np.ndarray):
    """Return the Gaussian kernel density estimate at each (x, y) point.

    The estimator is fitted on the points themselves and then evaluated at
    those same points.
    """
    points = np.vstack((x, y))
    return stats.gaussian_kde(points).evaluate(points)

class FormatScalarFormatter(matplotlib.ticker.ScalarFormatter):
    """ScalarFormatter that renders tick labels with one fixed printf-style format.

    Args:
        fformat: Format string applied to every tick value, e.g. ``"%1.1f"``.
        offset: Forwarded to ``ScalarFormatter`` as ``useOffset``.
        mathText: Forwarded to ``ScalarFormatter`` as ``useMathText``.
    """

    def __init__(self, fformat="%1.1f", offset=True, mathText=True):
        # Set before the base initialiser so the format is available as soon
        # as the base class needs it.
        self.fformat = fformat
        super().__init__(useOffset=offset, useMathText=mathText)

    def _set_format(self, vmin=None, vmax=None):
        """Force the fixed format instead of the automatically derived one.

        Matplotlib releases differ in how they invoke this hook (older ones
        pass ``vmin``/``vmax``, newer ones pass nothing), so both arguments
        are optional and ignored.
        """
        self.format = self.fformat
        if self._useMathText:
            # Written out explicitly rather than via the private helper
            # ``matplotlib.ticker._mathdefault``.
            self.format = r'$\mathdefault{%s}$' % self.format
def find_files_with_text(path, text):
    """Find all files below ``path`` whose text content contains ``text``.

    Files that cannot be decoded as text, or cannot be opened/read at all
    (dangling symlinks, permission problems, ...), are skipped so that one bad
    entry does not abort the whole scan.

    Args:
        path: Root directory of the recursive search.
        text: Substring to look for.

    Returns:
        ``(matching_files, file_count)``: full paths of the matching files and
        how many there are.
    """
    matching_files = []

    for root, _dirs, files in os.walk(path):
        for file in files:
            filepath = os.path.join(root, file)
            try:
                with open(filepath, 'r') as f:
                    if text in f.read():
                        matching_files.append(filepath)
            except (UnicodeDecodeError, OSError):
                # Non-text or unreadable file: ignore and keep scanning.
                continue

    return matching_files, len(matching_files)
def is_point_in_box(lat, lon, min_lat, max_lat, min_lon, max_lon):
    """Return whether (lat, lon) lies inside the box, edges included."""
    lat_inside = min_lat <= lat <= max_lat
    # The longitude test is only evaluated when the latitude test passed.
    return lat_inside and (min_lon <= lon <= max_lon)

def check_any_point_in_box(points, min_lat, max_lat, min_lon, max_lon):
    """Return True as soon as one (lat, lon) pair of ``points`` is inside the box.

    Returns False when no point is inside, including when ``points`` is empty.
    """
    return any(
        is_point_in_box(lat, lon, min_lat, max_lat, min_lon, max_lon)
        for lat, lon in points
    )


def is_coordinate_in_bbox(lat, lon, bbox):
    """
    Tell whether a coordinate falls inside a bounding box (edges included).

    Args:
    - lat (float): Latitude of the coordinate.
    - lon (float): Longitude of the coordinate.
    - bbox (tuple): Bounding box given as (min_lat, min_lon, max_lat, max_lon).

    Returns:
    - bool: True when the coordinate is inside the box, False otherwise.
    """
    south, west, north, east = bbox
    lat_inside = south <= lat <= north
    # The longitude test is only evaluated when the latitude test passed.
    return lat_inside and (west <= lon <= east)

###################################################################################################################
# Run-wide accumulators; each name starts as its own independent empty list.
all_omps_reflec, all_omps_sunz, all_omps_satez = [], [], []
all_viirs_aod, all_viirs_saai = [], []
all_omps_lat, all_omps_lon, all_omps_cldfrac = [], [], []
all_omps_lat_fov, all_omps_lon_fov = [], []
all_viirs_m3, all_viirs_m4, all_viirs_m5 = [], [], []
all_viirs_rfl3, all_viirs_rfl4, all_viirs_rfl5 = [], [], []
all_viirs_lat, all_viirs_lon = [], []

# Latitude/longitude limits of the region of interest and the box centre.
min_lat, max_lat = 15, 60
min_lon, max_lon = -135, -60
box_cenlat = 0.5 * (min_lat + max_lat)
box_cenlon = 0.5 * (min_lon + max_lon)
box_coor = (box_cenlat, box_cenlon)

# Satellite selection: short id -> directory label and n_viirs (the length of
# the VIIRS-pixel axis the script later reshapes the per-footprint arrays to).
sate = 'j02'
if sate == 'j02':
    sate_cap, n_viirs = 'N21', 400
elif sate == 'j01':
    sate_cap, n_viirs = 'N20', 4000
elif sate == 'npp':
    sate_cap, n_viirs = 'SNPP', 12000
elif sate == 'hrj01':
    sate_cap, n_viirs = 'HRN20', 1000

# savepca == 1: fit PCA on a set of training dates and save its parameters.
# savepca == 0: load saved PCA parameters and apply them to one date range.
savepca = 0
m = 0
thr2 = 0.2       # AOD value separating the "high" and "low" training samples
clear_thr = 0.5  # minimum per-footprint clear fraction kept for training
target = 'aod'

if savepca == 1:
    yearthre1, monththre1, daythre1 = 2024, 1, 8
    yearthre2, monththre2, daythre2 = 2024, 12, 26
    formatted_dates = [
        '20240108', '20240216', '20240310', '20240421',
        '20240518', '20240612', '20240716', '20240809',
        '20240911', '20241017', '20241123', '20241226',
    ]
    # Other dates, currently disabled:
    # '20231130','20231115', '20231120','20231125','20240701', '20240702','20240703',
    # '20240704','20240705', '20240715', '20240725','20240805','20240815', '20240825',
    # '20240905', '20240915', '20240925',
    start_str, end_str = formatted_dates[0], formatted_dates[-1]
elif savepca == 0:
    yearthre1, monththre1, daythre1 = 2024, 7, 21
    yearthre2, monththre2, daythre2 = yearthre1, monththre1, daythre1
    date1 = datetime.datetime(yearthre1, monththre1, daythre1)
    date2 = datetime.datetime(yearthre2, monththre2, daythre2)
    start_str, end_str = (day.strftime("%Y%m%d") for day in (date1, date2))
    dates = pd.date_range(start=date1, end=date2)
    # Every day of the range, formatted as 'yyyymmdd'.
    formatted_dates = dates.strftime("%Y%m%d").tolist()

file_nameidx = "_".join((start_str, end_str))
out_dir = '/data/data627/qliu/omps_aod_result/pca_viirs_to_omps/'
# Collect the collocated OMPS / VIIRS samples for every requested date, one row
# per OMPS footprint.  Per-granule arrays are gathered in Python lists and
# joined once at the end: growing the result with np.append re-copies
# everything collected so far for each granule, and the empty-list placeholder
# left by a date without usable granules makes np.append(..., axis=0) fail with
# a dimension error.
_base_fields = ('omps_reflec', 'omps_sunz', 'omps_satez', 'viirs_aod', 'viirs_saai',
                'omps_lat', 'omps_lon', 'omps_cldfrac')
# Footprint corners and per-pixel VIIRS fields are only kept when savepca == 0.
_viirs_fields = ('omps_lat_fov', 'omps_lon_fov', 'viirs_rfl3', 'viirs_rfl4',
                 'viirs_rfl5', 'viirs_lat', 'viirs_lon')
_collected = {name: [] for name in _base_fields + (_viirs_fields if savepca == 0 else ())}
thr = 0  # lower bound on the granule-mean AOD, enforced only when savepca == 1

for idate in formatted_dates:
    print(idate)
    if sate_cap == 'HRN20':
        viirs_cpr_txtdir = '/data/data627/qliu/google_drive/N20/COL_OMPS_VIIRS_HRN20/' + idate + '/'
    else:
        viirs_cpr_txtdir = '/data/data627/qliu/google_drive/' + sate_cap + '/COL_OMPS_VIIRS/' + idate + '/'
    file_nameidx_cldsat = '.nc'
    all_viirscpr_files, n_all_viirscpr = get_subfiles(viirs_cpr_txtdir, file_nameidx_cldsat)
    n_parts_before = len(_collected['omps_reflec'])

    for file_name in sorted(all_viirscpr_files):
        granule_path = viirs_cpr_txtdir + file_name

        # The file name carries the date after '_d' and the start/end times
        # after '_t' / '_e'.  A name that does not follow this pattern raises here.
        # NOTE(review): the end time reuses the start date, so a granule that
        # crosses midnight gets an end time earlier than its start time.
        date_part = file_name.split('_d')[1]
        start_part = file_name.split('_t')[1]
        end_part = file_name.split('_e')[1]
        granule_day = (int(date_part[:4]), int(date_part[4:6]), int(date_part[6:8]))
        omps_s_time = datetime.datetime(*granule_day, int(start_part[0:2]),
                                        int(start_part[2:4]), int(start_part[4:6]))
        omps_e_time = datetime.datetime(*granule_day, int(end_part[0:2]),
                                        int(end_part[2:4]), int(end_part[4:6]))

        omps_lat = extract_h4_by_name(granule_path, 'omps_lat')
        omps_lon = extract_h4_by_name(granule_path, 'omps_lon')

        # Pre-select granules by their centre footprint: skip those with an
        # invalid centre coordinate or lying more than 6000 km from the box centre.
        center_idx = (int(0.5 * omps_lat.shape[0]), int(0.5 * omps_lat.shape[1]))
        center_lat = omps_lat[center_idx]
        center_lon = omps_lon[center_idx]
        if abs(center_lon) > 180 or abs(center_lat) > 90:
            continue
        distance_km = haversine((center_lat, center_lon), box_coor)
        if distance_km > 6000:  # 4000
            continue

        # OMPS reflectance = radiance / solar flux, using the channels from
        # index 77 onwards.  Non-positive inputs and reflectances outside
        # [0, 1] are set to NaN.
        omps_rad = extract_h4_by_name(granule_path, 'omps_rad')[:, :, 77:]
        omps_rad[omps_rad <= 0] = np.nan
        omps_solar = extract_h4_by_name(granule_path, 'omps_solar')[:, :, 77:]
        omps_solar[omps_solar <= 0] = np.nan
        omps_reflec = omps_rad / omps_solar
        omps_reflec[(omps_reflec < 0) | (omps_reflec > 1)] = np.nan
        nx_omps, ny_omps, nban_omps = omps_reflec.shape[:3]
        n_fov = nx_omps * ny_omps

        # VIIRS pixels with a negative band-0 radiance are treated as missing
        # in every VIIRS field.  Only band 0 is needed to build this mask.
        viirs_M_rad3 = extract_h4_by_name(granule_path, 'viirs_M_rad')[:, :, 0, :].squeeze()
        idx_omps_out = np.where(viirs_M_rad3 < 0)

        # AOD: keep only pixels with qc == 0 and -100 <= AOD <= 5, then average
        # the VIIRS pixels of each OMPS footprint (axis 2).
        viirs_M_aod = extract_h4_by_name(granule_path, 'viirs_M_aod')
        viirs_M_aod[idx_omps_out] = np.nan
        viirs_M_aod_qc = extract_h4_by_name(granule_path, 'viirs_M_aod_qc')
        viirs_M_aod[(viirs_M_aod_qc != 0) | (viirs_M_aod > 5) | (viirs_M_aod < -100)] = np.nan
        idx_invalid_aod = np.where(np.isnan(viirs_M_aod))
        viirs_M_aod = np.nanmean(viirs_M_aod, axis=2)

        # SAAI is averaged over exactly the pixels that have a valid AOD.
        viirs_M_SAAI = extract_h4_by_name(granule_path, 'viirs_M_SAAI')
        viirs_M_SAAI[idx_invalid_aod] = np.nan
        viirs_M_SAAI = np.nanmean(viirs_M_SAAI, axis=2)

        viirs_M_cloud = extract_h4_by_name(granule_path, 'viirs_M_cloud').astype(float)
        viirs_M_cloud[idx_omps_out] = np.nan

        # Skip granules without any valid AOD; for PCA training also skip
        # granules whose mean AOD is below thr.
        granule_mean_aod = np.nanmean(viirs_M_aod)
        if np.isnan(granule_mean_aod):
            continue
        if savepca == 1 and granule_mean_aod < thr:
            continue

        # Per-footprint fraction of available VIIRS pixels whose cloud value is 0.
        # NOTE(review): presumably 0 means "clear" (the result is later written
        # as 'omps_clearfrac') - confirm against the collocation product.
        viirs_M_cloud = viirs_M_cloud.reshape(n_fov, n_viirs)
        nviirs_per_omps = np.sum(~np.isnan(viirs_M_cloud), axis=1)
        cldviirs_per_omps = np.sum(viirs_M_cloud == 0, axis=1)
        cldfrac_peromps = cldviirs_per_omps / nviirs_per_omps

        granule = {
            'omps_reflec': omps_reflec.reshape(n_fov, nban_omps),
            'omps_sunz': extract_h4_by_name(granule_path, 'omps_sunZenith').flatten(),
            'omps_satez': extract_h4_by_name(granule_path, 'omps_satZenith').flatten(),
            'viirs_aod': viirs_M_aod.flatten(),
            'viirs_saai': viirs_M_SAAI.flatten(),
            'omps_lat': omps_lat.flatten(),
            'omps_lon': omps_lon.flatten(),
            'omps_cldfrac': cldfrac_peromps,
        }
        if savepca == 0:
            # Read the multi-band array once and split it, instead of
            # re-reading the whole variable for every band.
            viirs_M_pyv = extract_h4_by_name(granule_path, 'viirs_M_pyv')
            granule.update({
                'omps_lat_fov': extract_h4_by_name(granule_path, 'omps_lat_fov').reshape(n_fov, 4),
                'omps_lon_fov': extract_h4_by_name(granule_path, 'omps_lon_fov').reshape(n_fov, 4),
                'viirs_rfl3': viirs_M_pyv[:, :, 0, :].reshape(n_fov, n_viirs),
                'viirs_rfl4': viirs_M_pyv[:, :, 1, :].reshape(n_fov, n_viirs),
                'viirs_rfl5': viirs_M_pyv[:, :, 2, :].reshape(n_fov, n_viirs),
                'viirs_lat': extract_h4_by_name(granule_path, 'viirs_M_lat').reshape(n_fov, n_viirs),
                'viirs_lon': extract_h4_by_name(granule_path, 'viirs_M_lon').reshape(n_fov, n_viirs),
            })
        for _name, _values in granule.items():
            _collected[_name].append(_values)

    # Shape of the reflectance matrix gathered for this date ((0,) if none).
    _date_parts = _collected['omps_reflec'][n_parts_before:]
    if _date_parts:
        print((sum(part.shape[0] for part in _date_parts),) + _date_parts[0].shape[1:])
    else:
        print((0,))

if not _collected['omps_reflec']:
    raise RuntimeError('no usable OMPS/VIIRS granules found for ' + file_nameidx)

all_omps_reflec = np.concatenate(_collected['omps_reflec'], axis=0)
all_omps_sunz = np.concatenate(_collected['omps_sunz'], axis=0)
all_omps_satez = np.concatenate(_collected['omps_satez'], axis=0)
all_viirs_aod = np.concatenate(_collected['viirs_aod'], axis=0)
all_viirs_saai = np.concatenate(_collected['viirs_saai'], axis=0)
all_omps_lat = np.concatenate(_collected['omps_lat'], axis=0)
all_omps_lon = np.concatenate(_collected['omps_lon'], axis=0)
all_omps_cldfrac = np.concatenate(_collected['omps_cldfrac'], axis=0)
if savepca == 0:
    all_omps_lat_fov = np.concatenate(_collected['omps_lat_fov'], axis=0)
    all_omps_lon_fov = np.concatenate(_collected['omps_lon_fov'], axis=0)
    all_viirs_rfl3 = np.concatenate(_collected['viirs_rfl3'], axis=0)
    all_viirs_rfl4 = np.concatenate(_collected['viirs_rfl4'], axis=0)
    all_viirs_rfl5 = np.concatenate(_collected['viirs_rfl5'], axis=0)
    all_viirs_lat = np.concatenate(_collected['viirs_lat'], axis=0)
    all_viirs_lon = np.concatenate(_collected['viirs_lon'], axis=0)
# Build the per-footprint table all_stat and apply the row filters.
all_omps_reflec = np.asarray(all_omps_reflec)
all_omps_sunz = np.asarray(all_omps_sunz)
all_omps_satez = np.asarray(all_omps_satez)
all_viirs_aod = np.asarray(all_viirs_aod)
all_viirs_saai = np.asarray(all_viirs_saai)
all_omps_lat = np.asarray(all_omps_lat)
all_omps_lon = np.asarray(all_omps_lon)
all_omps_cldfrac = np.asarray(all_omps_cldfrac)

# Column layout of all_stat:
#   0 lat, 1 lon, 2 sun zenith, 3 satellite zenith, 4 clear fraction,
#   5 AOD, 6 SAAI, 7... OMPS reflectance spectrum.
all_stat = np.concatenate((all_omps_lat.reshape(-1, 1),
                           all_omps_lon.reshape(-1, 1),
                           all_omps_sunz.reshape(-1, 1),
                           all_omps_satez.reshape(-1, 1),
                           all_omps_cldfrac.reshape(-1, 1),
                           all_viirs_aod.reshape(-1, 1),
                           all_viirs_saai.reshape(-1, 1),
                           all_omps_reflec),
                          axis=1)
# np.shape() also accepts the empty-list placeholder that all_viirs_rfl3 still
# holds when savepca == 1 (a list has no .shape attribute).
print(all_stat.shape, all_viirs_aod.shape, np.shape(all_viirs_rfl3))
print(np.nanmean(all_omps_cldfrac))

# Keep footprints with sun zenith angle < 80 and a complete (NaN-free)
# reflectance spectrum.
keep_rows = (all_omps_sunz < 80) & ~np.isnan(all_omps_reflec).any(axis=1)
all_stat = all_stat[keep_rows]
all_omps_reflec = all_omps_reflec[keep_rows]
all_viirs_aod = all_stat[:, 5].copy()
all_omps_cldfrac = all_stat[:, 4].copy()
if savepca == 0:
    # Every per-footprint array gets the same row selection, so row i of the
    # VIIRS pixel arrays keeps describing the same footprint as row i of all_stat.
    all_omps_lat_fov = all_omps_lat_fov[keep_rows]
    all_omps_lon_fov = all_omps_lon_fov[keep_rows]
    all_viirs_rfl3 = all_viirs_rfl3[keep_rows]
    all_viirs_rfl4 = all_viirs_rfl4[keep_rows]
    all_viirs_rfl5 = all_viirs_rfl5[keep_rows]
    all_viirs_lat = all_viirs_lat[keep_rows]
    all_viirs_lon = all_viirs_lon[keep_rows]

if savepca == 1:
    # Training set: footprints with a valid AOD and enough clear pixels ...
    valid_mask = (~np.isnan(all_viirs_aod)) & (all_omps_cldfrac >= clear_thr)
    all_stat = all_stat[valid_mask, :]
    # ... split at AOD == thr2 into a "high" and a "low" group ...
    high_rows = all_stat[all_stat[:, 5] > thr2]
    low_rows = all_stat[all_stat[:, 5] <= thr2]
    print("high_rows shape:", high_rows.shape)
    print("low_rows shape:", low_rows.shape)
    # ... balanced by drawing the same number of rows from each group without
    # replacement (fixed seed for reproducibility), then shuffled.
    n_select = min(high_rows.shape[0], low_rows.shape[0])
    np.random.seed(42)
    high_sample = high_rows[np.random.choice(high_rows.shape[0], size=n_select, replace=False)]
    low_sample = low_rows[np.random.choice(low_rows.shape[0], size=n_select, replace=False)]
    all_stat = np.vstack([high_sample, low_sample])
    np.random.shuffle(all_stat)

# PCA input: up to 119 reflectance columns starting at column n_reflec.
n_reflec = 7
omps_sdr_reflec_valid = np.array(all_stat[:, n_reflec: n_reflec + 119])
sif_sample = omps_sdr_reflec_valid.shape
npc = sif_sample[1]
X = omps_sdr_reflec_valid  # nsamples * nbands
print(X.shape)
# Shared part of every PCA-related file name: target, AOD threshold and
# clear-fraction threshold.
_pca_tag = '_' + target + '_' + str(thr2) + '_clear' + str(clear_thr)

if savepca == 1:
    # --- Training: standardise each band, fit PCA, save all parameters. ---
    this_sp_rad_std = np.nanstd(omps_sdr_reflec_valid, axis=0)
    this_sp_rad_mean = np.nanmean(omps_sdr_reflec_valid, axis=0)
    X_norm = (X - this_sp_rad_mean) / this_sp_rad_std  # nsamples * nbands
    pca = PCA(n_components=npc)
    X_pca = pca.fit_transform(X_norm)  # nsamples * npc
    print(X_pca.shape)
    mean_pcs = np.nanmean(X_pca, axis=0)
    std_pcs = np.nanstd(X_pca, axis=0)
    print(mean_pcs.shape, std_pcs.shape)
    # NOTE(review): the training file names always end in '_j02', whatever
    # `sate` is set to - confirm this is intended.
    para_path = (out_dir + 'PCtxt/pc_train/' + file_nameidx + '_' + str(npc)
                 + '_pca_parameter' + _pca_tag + '_j02.h5')
    with h5py.File(para_path, 'w') as h5file:
        h5file.create_dataset('components_', data=pca.components_)
        h5file.create_dataset('explained_variance_', data=pca.explained_variance_)
        h5file.create_dataset('explained_variance_ratio_', data=pca.explained_variance_ratio_)
        h5file.create_dataset('singular_values_', data=pca.singular_values_)
        h5file.create_dataset('mean_', data=pca.mean_)
        h5file.create_dataset('mean_pcs', data=mean_pcs)
        h5file.create_dataset('std_pcs', data=std_pcs)
        h5file.create_dataset('n_components_', data=pca.n_components_)
        h5file.create_dataset('mean', data=this_sp_rad_mean)
        h5file.create_dataset('std', data=this_sp_rad_std)
    # Standardise the principal-component scores themselves.
    X_pca_norm = (X_pca - mean_pcs) / std_pcs
    output_stat = (out_dir + 'PCtxt/pc_train/' + file_nameidx + '_' + str(npc)
                   + '_pca_norm' + _pca_tag + '_j02.nc')
    print("PCA parameters saved to " + para_path)

if savepca == 0:
    # --- Application: load saved parameters and project the new spectra. ---
    # NOTE(review): the parameter file name (date range, 119 components, j02)
    # is hard-coded - confirm it is the training run meant to be applied here.
    para_path = (out_dir + 'PCtxt/pc_train/'
                 + '20231115_20240925_119_pca_parameter' + _pca_tag + '_j02.h5')
    with h5py.File(para_path, 'r') as h5file:
        components_ = np.array(h5file['components_'])
        explained_variance_ = np.array(h5file['explained_variance_'])
        explained_variance_ratio_ = np.array(h5file['explained_variance_ratio_'])
        singular_values_ = np.array(h5file['singular_values_'])
        mean_ = np.array(h5file['mean_'])
        std_pc = np.array(h5file['std_pcs'])
        mean_pc = np.array(h5file['mean_pcs'])
        n_components_ = int(h5file['n_components_'][()])
        mean = np.array(h5file['mean'])
        std = np.array(h5file['std'])
        print(mean.shape, std.shape)
    print("PCA model loaded successfully from " + para_path)
    # Same band-wise standardisation as in training ...
    X_norm = (X - mean) / std
    # ... then the PCA's own centring (its stored mean_), as fit_transform
    # applied it when the parameters were fitted, before projecting onto the
    # components.
    X_centered = X_norm - mean_
    X_pca = np.dot(X_centered, components_.T)
    X_pca_norm = (X_pca - mean_pc) / std_pc
    output_stat = (out_dir + 'PCtxt/' + file_nameidx + '_' + str(npc)
                   + '_pca_norm' + _pca_tag + '_' + sate + '.nc')
# Write the per-footprint table, the normalised PC scores and (when
# savepca == 0) the footprint corners and per-pixel VIIRS fields to NetCDF.
print(X_pca_norm.shape)
# The context manager closes the file even if one of the writes fails.
with Dataset(output_stat, mode='w', format='NETCDF4') as ncfile:
    ncfile.createDimension('x', X_pca_norm.shape[0])   # samples (OMPS footprints)
    ncfile.createDimension('y', X_pca_norm.shape[1])   # principal components
    if savepca == 0:
        # The per-pixel VIIRS arrays only exist when savepca == 0; otherwise
        # all_viirs_rfl3 is still a plain list without a shape.
        ncfile.createDimension('vx', all_viirs_rfl3.shape[0])
        ncfile.createDimension('vy', all_viirs_rfl3.shape[1])
    ncfile.createDimension('vpixel', n_viirs)
    ncfile.createDimension('f', 4)                      # footprint corner values

    # (variable name, dimensions, values), written in this order.
    _outputs = [
        (_name, ('x',), all_stat[:, _col])
        for _col, _name in enumerate(('omps_lat', 'omps_lon', 'omps_sunz', 'omps_satez',
                                      'omps_clearfrac', 'omps_aod', 'omps_saai'))
    ]
    _outputs.append(('X_pca_norm', ('x', 'y'), X_pca_norm))
    if savepca == 0:
        _outputs += [
            ('omps_lat_fov', ('x', 'f'), all_omps_lat_fov),
            ('omps_lon_fov', ('x', 'f'), all_omps_lon_fov),
            ('all_viirs_rfl3', ('vx', 'vy'), all_viirs_rfl3),
            ('all_viirs_rfl4', ('vx', 'vy'), all_viirs_rfl4),
            ('all_viirs_rfl5', ('vx', 'vy'), all_viirs_rfl5),
            ('all_viirs_lat', ('vx', 'vy'), all_viirs_lat),
            ('all_viirs_lon', ('vx', 'vy'), all_viirs_lon),
        ]
    for _name, _dims, _values in _outputs:
        data_var = ncfile.createVariable(_name, np.float32, _dims)
        data_var.long_name = _name
        data_var[:] = _values
print('finish ' + output_stat)