# -*- coding: utf-8 -*-
"""
Created on Tue Jul 15 14:13:07 2025

@author: Charlotte.Kastoun
"""
#~~~~~~~~~~LIBRARIES~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

import urllib.request
import xarray as xr
import numpy as np
from matplotlib import pyplot as plt
from scipy import stats
import pandas as pd

import warnings
warnings.filterwarnings('ignore')

import seaborn as sns
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.lines import Line2D


#~~~~~~~~~~EXPLANATIONS~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# download_data: downloads the SST and CHL data from ERDDAP using ERDDAP-generated URLs.
#                I had to download it in chunks of about 5 years, otherwise it would crash.

# create_ds: actually loads in the netcdf data into python as an xarray DataSet, 
#            then concatenates the DataSet for each chunk of 5 years into 1 ds for each variable.

# process_data: removes duplicate timesteps from the datasets, regrids chl to match the sst grid,
#               and resamples both to weekly means on shared timestamps.
#               explanation of code is included in the function. returns one combined ds with both vars (sst, chl)

# examine_data: looks at the structure and variables of the DataSet

# create_basicmap: creates one map for the sst data and one for chl data at a given timestep 
#                  (preliminary visualization).
#                  adapted from ERDDAP tutorials on github

# find_pixels: finds the nearest pixel given the lat and lon of a station. Then finds a 3x3
#              box around that pixel

# calc_station_mean: calculates the spatial mean sst and chl
#                    for the 3x3 box of pixels at that station for each timestep
#                    returns the means as one xarray DataSet  

# plot_correlation: draws a box plot of chl grouped into 5-degree sst bins
#                   to look at sst and chl correlations
#                   (the scatter plot / best fit line / linear regression is not currently drawn).
#                   saves your figure to the same local directory on your device as your code.

# plot_timeseries_correlation: plots the mean sst and mean chl for each week of the year, averaged across all years

# plot_timeseries_correlation_by_year: plots one line for sst and one for chl on the same subplot, 
#                                      one subplot for each year

#~~~~~~~~~~~~~~~~DOWNLOADING DATA~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
def download_data():
    """Fetch the SST and chlorophyll NetCDF subsets from the CoastWatch ERDDAP server.

    Each request below is saved under a fixed, relative file name, one after
    the other in the listed order. The record is requested in multi-year
    pieces rather than as a single download.
    """
    # (ERDDAP-generated URL, local file name) pairs, in download order.
    requests_in_order = (
        # SST DATA
        # 2006-2010
        ('https://coastwatch.noaa.gov/erddap/griddap/noaacwecnAVHRRVIIRSmultisensorSSTeastcoast7Day.nc?sst%5B(2006-12-06T14:00:00Z):1:(2010-01-04T01:45:00Z)%5D%5B(0.0):1:(0.0)%5D%5B(40):1:(37)%5D%5B(-77):1:(-76)%5D',
         "sst06-10.nc"),
        # 2010-2015
        ('https://coastwatch.noaa.gov/erddap/griddap/noaacwecnAVHRRVIIRSmultisensorSSTeastcoast7Day.nc?sst%5B(2010-01-04T01:45:00Z):1:(2015-11-29T04:00:00Z)%5D%5B(0.0):1:(0.0)%5D%5B(40):1:(37)%5D%5B(-77):1:(-76)%5D',
         "sst10-15.nc"),
        # 2015-2020
        ('https://coastwatch.noaa.gov/erddap/griddap/noaacwecnAVHRRVIIRSmultisensorSSTeastcoast7Day.nc?sst%5B(2015-11-29T04:00:00Z):1:(2020-04-04T04:20:00Z)%5D%5B(0.0):1:(0.0)%5D%5B(40):1:(37)%5D%5B(-77):1:(-76)%5D',
         "sst15-20.nc"),
        # 2020-2025
        ('https://coastwatch.noaa.gov/erddap/griddap/noaacwecnAVHRRVIIRSmultisensorSSTeastcoast7Day.nc?sst%5B(2020-04-04T04:20:00Z):1:(2025-05-31T04:09:59Z)%5D%5B(0.0):1:(0.0)%5D%5B(40):1:(37)%5D%5B(-77):1:(-76)%5D',
         "sst20-25.nc"),
        # CHLORA DATA
        # 2020-2025
        ('https://coastwatch.noaa.gov/erddap/griddap/noaacwecnOLCImultisensorCHLeastcoast7Day.nc?chlora%5B(2020-01-07):1:(2025-05-27T00:00:00Z)%5D%5B(40):1:(37)%5D%5B(-77):1:(-76)%5D',
         "chlorophyll20-25.nc"),
        # 2016-2019 (stored under the "16-20" file name)
        ('https://coastwatch.noaa.gov/erddap/griddap/noaacwecnOLCImultisensorCHLeastcoast7Day.nc?chlora%5B(2016-04-28):1:(2019-12-31)%5D%5B(40):1:(37)%5D%5B(-77):1:(-76)%5D',
         "chlorophyll16-20.nc"),
    )

    for source_url, target_file in requests_in_order:
        # urlretrieve performs the actual download and writes the file.
        urllib.request.urlretrieve(source_url, target_file)

#~~~~~~~~~~~~~~~~FUNCTIONS~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
   
def create_ds():
    """Open the downloaded NetCDF files and join them along time.

    Only the 2015-2020 and 2020-2025 SST files are opened here, together with
    both chlorophyll files.

    Returns:
        (ds_chlora, ds_sst): the chlorophyll Dataset and the SST Dataset,
        each concatenated along the "time" dimension in chronological
        file order.
    """
    # Files are opened in the same order as before: SST first, then chlorophyll.
    sst_parts = [
        xr.open_dataset(nc_path, decode_cf=True)
        for nc_path in ('sst15-20.nc', 'sst20-25.nc')
    ]
    chl_recent = xr.open_dataset('chlorophyll20-25.nc', decode_cf=True)
    chl_early = xr.open_dataset('chlorophyll16-20.nc', decode_cf=True)

    # Earlier chunk goes first so time runs forward in the joined Dataset.
    chl_joined = xr.concat([chl_early, chl_recent], dim="time")
    sst_joined = xr.concat(sst_parts, dim="time")

    return chl_joined, sst_joined

def process_data(chl_ds, sst_ds):
    """Clean, regrid and time-align the chlorophyll and SST Datasets.

    Args:
        chl_ds: Dataset with a ``chlora`` variable on time/latitude/longitude.
        sst_ds: Dataset with an ``sst`` variable on time/latitude/longitude.

    Returns:
        One Dataset holding ``chlora`` and ``sst`` on the SST grid, restricted
        to the weekly time stamps both inputs share.
    """
    # Step 1: drop repeated time stamps (the downloaded chunks can share
    # their boundary timestep).
    chl_unique = chl_ds.drop_duplicates(dim='time')
    sst_unique = sst_ds.drop_duplicates(dim='time')

    # Step 2: put chlorophyll on the SST latitude/longitude points so the two
    # variables can be compared pixel by pixel.
    # NOTE(review): interp() estimates chl AT the SST grid points by
    # interpolation; it does not average the chl pixels that fall inside each
    # SST cell. If the SST grid is the coarser one (presumably - confirm) and
    # true aggregation is wanted, this step needs a different method.
    chl_on_sst_grid = chl_unique.interp(latitude=sst_unique.latitude,
                                        longitude=sst_unique.longitude)

    # Step 3: resample both to weekly means so their time stamps line up
    # (the raw products do not share time stamps).
    chl_weekly = chl_on_sst_grid.resample(time='W').mean()
    sst_weekly = sst_unique.resample(time='W').mean()

    # Step 4: keep only the coordinates present in both Datasets.
    # (This alignment used to be executed twice; once is enough.)
    chl_final, sst_final = xr.align(chl_weekly, sst_weekly, join='inner')

    # Step 5: combine into a single Dataset.
    combined = xr.Dataset({
        'chlora': chl_final.chlora,
        'sst': sst_final.sst
    })

    return combined

def examine_data(dataset):
    """Print a quick structural summary of a Dataset.

    Reports the coordinate names, the data variable names, the shape of every
    expected variable that is present (``sst`` and/or ``chlora``), the time
    coordinate, and the first/last latitude values.

    Args:
        dataset: Dataset with ``time`` and ``latitude`` coordinates.
    """
    print('The coordinates variables:', list(dataset.coords), '\n')
    print('The data variables:', list(dataset.data_vars), '\n')

    # Report each expected variable independently: the combined Dataset holds
    # both, and an if/elif chain would only ever show the first one.
    found_any = False
    for var_name in ('sst', 'chlora'):
        if var_name in dataset:
            found_any = True
            print('Shape/Structure of ' + var_name + ' variable:')
            print(dataset[var_name].shape)
    if not found_any:
        print('Expected datavariable not in DataSet')

    print('\n Dates for each time step: ')
    print(dataset.time)

    # examine if the latitude coordinate variable is in ascending or descending order
    # by looking at the first and last values in the lat array
    # relevant for when you need to slice (subset) the Dataset later
    print('First latitude value', dataset.latitude[0].item())
    print('Last latitude value', dataset.latitude[-1].item())


def _plot_station_map(data, var_name, timestep, vmin, vmax, cmap_name, title_prefix):
    """Draw and show one filled-contour map of ``data[var_name]`` at ``timestep``."""
    # An all-NaN slice gives NaN limits, which cannot define contour levels.
    if not (np.isfinite(vmin) and np.isfinite(vmax)):
        print('No finite ' + var_name + ' values at timestep ' + str(timestep)
              + '; skipping map')
        return

    # custom colormap (reversed plasma) with one colour per contour level,
    # based on the minimum and maximum value of the field
    levs = np.arange(np.floor(vmin), np.floor(vmax)+1, 0.05)
    num_colors = len(levs)
    colors = plt.cm.plasma(np.linspace(0, 1, num_colors))[::-1]
    cm = LinearSegmentedColormap.from_list(cmap_name, colors, N=num_colors)

    plt.contourf(data.longitude,
                 data.latitude,
                 data[var_name][timestep, :, :],
                 levs,
                 cmap=cm)
    plt.colorbar()

    # Annotation: one black point per station, given as (lon, lat)
    station_points = (
        (-76.0587, 38.5807),   # ET5.2
        (-76.598, 38.1576),    # LE2.2
        (-76.52254, 39.21309), # WT5.1
        (-76.38634, 37.24181), # WE4.2
        (-76.78889, 37.50869), # RET4.3
    )
    for station_lon, station_lat in station_points:
        plt.scatter(station_lon, station_lat, c='black')

    plt.title(title_prefix + data.time[timestep].dt.strftime('%b %Y').item())
    plt.show()


def create_basicmap(data, timestep):
    """Show a chlorophyll map and an SST map for one timestep.

    Prints the minimum and maximum of both variables at ``timestep``, then
    draws the chlorophyll map followed by the SST map, each with the station
    locations marked.

    Args:
        data: Dataset with ``chlora`` and ``sst`` variables and
            ``time``/``latitude``/``longitude`` coordinates.
        timestep: integer index along the time axis.
    """
    min_chlora = np.nanmin(data.chlora[timestep])
    max_chlora = np.nanmax(data.chlora[timestep])

    min_sst = np.nanmin(data.sst[timestep])
    max_sst = np.nanmax(data.sst[timestep])

    print("minimum chlorophyll value: ", min_chlora)
    print("maximum chlorophyll value: ", max_chlora)

    print("minimum sst value: ", min_sst)
    print("maximum sst value: ", max_sst)

    # The titles name the requested timestep; they used to say
    # "first timestep" whatever index was passed in.
    _plot_station_map(data, 'chlora', timestep, min_chlora, max_chlora,
                      "chlorophyll_cmap",
                      "Chlorophyll-a at timestep " + str(timestep) + ": ")
    _plot_station_map(data, 'sst', timestep, min_sst, max_sst,
                      "sst_cmap",
                      "SST at timestep " + str(timestep) + ": ")
    
def find_pixels(station, lat, lon, data):
    """Select the box of pixels centred on the grid pixel nearest a station.

    The nearest pixel is the one whose latitude and longitude values have the
    smallest absolute difference from ``lat`` and ``lon``. The box extends one
    pixel either side of it (3x3). When the nearest pixel lies on an edge of
    the grid the box is clipped to the grid, so it can be smaller than 3x3;
    previously that case failed with an UnboundLocalError on return.

    Args:
        station: station name, used only in the printed messages.
        lat: station latitude.
        lon: station longitude.
        data: Dataset with ``latitude`` and ``longitude`` coordinates.

    Returns:
        The subset of ``data`` covering the selected box.
    """
    lat_values = data.latitude.values
    lon_values = data.longitude.values

    # index of the nearest existing latitude / longitude in the dataset
    lat_index = int(np.abs(lat_values - lat).argmin())
    lon_index = int(np.abs(lon_values - lon).argmin())

    # One pixel either side of the nearest one, clipped to the grid bounds.
    lat_start = max(lat_index - 1, 0)
    lat_stop = min(lat_index + 2, lat_values.size)
    lon_start = max(lon_index - 1, 0)
    lon_stop = min(lon_index + 2, lon_values.size)

    if (lat_stop - lat_start < 3) or (lon_stop - lon_start < 3):
        print(station + ' is at the edge of the grid; using a clipped '
              + str(lat_stop - lat_start) + 'x' + str(lon_stop - lon_start)
              + ' box')

    # Positional selection picks the same pixels the value-based lookup did.
    station_data = data.isel(latitude=slice(lat_start, lat_stop),
                             longitude=slice(lon_start, lon_stop))
    print(station+' data successfully selected')

    return station_data
    
def calc_station_mean(station, data):
    """Average SST and chlorophyll over the station's pixel box per timestep.

    Args:
        station: station name, used only in the printed message.
        data: Dataset with ``sst`` and ``chlora`` on latitude/longitude/time.

    Returns:
        Dataset with ``chlora`` and ``sst`` means along time, with every
        timestep that has a missing value dropped.
    """
    box_dims = ['latitude', 'longitude']

    # NaN pixels are skipped so a partly missing box still yields a mean.
    mean_sst = data['sst'].mean(dim=box_dims, skipna=True)
    print(station+' SST spatial mean successfully found')
    mean_chl = data['chlora'].mean(dim=box_dims, skipna=True)

    station_means = xr.Dataset({'chlora': mean_chl, 'sst': mean_sst})

    return station_means.dropna(dim='time')

def plot_correlation(station, combined_data):
    """Box-plot chlorophyll against SST grouped into 5-degree bins and save it.

    The figure is shown and then written to
    ``'<station> SST vs. CHL Correlation.png'``. Only the box plot is drawn;
    the earlier scatter / best-fit-line variant lived in an inert string
    literal that referenced undefined names and has been removed.

    Args:
        station: station name, used in the title and the output file name.
        combined_data: Dataset with ``sst`` and ``chlora`` along time.
    """
    combined_df = combined_data.to_dataframe().reset_index()

    # Without any SST value the bin edges cannot be computed (int(NaN) fails),
    # so report it and leave instead of crashing.
    if not combined_df['sst'].notna().any():
        print(station + ' has no SST data to plot')
        return

    fig, ax = plt.subplots(figsize=(12, 6), dpi=300)

    # ~~~~ BOX PLOTS ~~~~

    # bin data (groups of 5 degrees) and sort by bins.
    # floor() rather than int(): int() truncates toward zero, which would put
    # a negative minimum SST below the first bin edge and leave it unbinned.
    lowest_edge = int(np.floor(combined_df['sst'].min()))
    highest_value = int(np.floor(combined_df['sst'].max()))
    combined_df['sst_bin'] = pd.cut(combined_df['sst'],
                                    bins=range(lowest_edge, highest_value + 6, 5),
                                    include_lowest=True)

    # Draw explicitly on this figure's axes instead of the implicit current axes.
    sns.boxplot(x='sst_bin', y='chlora', data=combined_df,
                color='rebeccapurple', ax=ax)

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

    ax.set_xlabel('SST (ºC)')  # binned variable on the x axis
    ax.set_ylabel('Chlorophyll a (mg m^-3)')

    ax.set_title(station + ' SST vs. Chlorophyll Correlation')

    plt.show()

    fig.savefig(station+ ' SST vs. CHL Correlation.png', dpi=300)
    
def plot_timeseries_correlation(station, combined_data):
    """Plot mean SST and mean chlorophyll per week of the year, over all years.

    The figure is shown and then written to
    ``'<station> SST and CHL Avged Timeseries Corr.png'``.

    Args:
        station: station name, used in the title and the output file name.
        combined_data: Dataset with ``sst`` and ``chlora`` along time.

    Returns:
        The list of line artists returned by the SST ``ax.plot`` call.
    """
    # group by ISO week number and find the mean of each group.
    # isocalendar().week replaces the deprecated 'time.week' accessor; the
    # grouped dimension is still named "week". groupby does not modify its
    # input, so no defensive copy is needed.
    week_of_year = combined_data.time.dt.isocalendar().week
    weekly_means_timeseries_ds = combined_data.groupby(week_of_year).mean()
    print("Successfully found means for each week across timeseries")

    fig, ax = plt.subplots(figsize=(12, 6), dpi=300)
    plot = ax.plot(weekly_means_timeseries_ds.week, weekly_means_timeseries_ds.sst,
                   color='#cc4778', alpha=0.7, label='SST')
    ax.plot(weekly_means_timeseries_ds.week, weekly_means_timeseries_ds.chlora,
            color='#0d0887', alpha=0.7, label='Chlorophyll a')

    ax.set_xlim([0, 53])
    ax.set_ylim([0, 110])

    # One tick per month, spaced 4.3 weeks apart along the week-number axis.
    months_ticks_list = [month_number*4.3 for month_number in range(12)]
    months_labels_list = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                          'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']

    ax.set_xticks(months_ticks_list)
    ax.set_xticklabels(months_labels_list)

    ax.set_xlabel('Month')  # time of year on the x axis
    ax.set_ylabel('Chlorophyll a (mg m^-3) | SST (ºC)')
    ax.legend()

    ax.set_title(station + ' SST vs. Chlorophyll Timeseries Correlation')

    plt.show()

    fig.savefig(station+ ' SST and CHL Avged Timeseries Corr.png', dpi=300)

    return(plot)
    
def plot_timeseries_correlation_by_year(station, combined_data):
    """Plot SST and chlorophyll against week of year, one subplot per year.

    Draws a 2x4 grid covering 2016-2023, shows it, then writes it to
    ``'<station> SST and CHL Timeseries Corr by year.png'``.

    Args:
        station: station name, used in subplot titles and the file name.
        combined_data: Dataset with ``sst`` and ``chlora`` along time.
    """
    years_list = np.array([[2016, 2017, 2018, 2019],
                          [2020, 2021, 2022, 2023]])

    # creating subplots: same 2x4 layout as the years grid
    fig, axes = plt.subplots(nrows=2, ncols=4, figsize=(30, 13))

    # One tick per month, spaced 4.3 weeks apart along the week-number axis.
    months_ticks_list = [month_number*4.3 for month_number in range(12)]
    months_labels_list = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
                          'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']

    # axes.flat and years_list.flat both walk the grid row by row, so each
    # year lands in the same subplot as before.
    for ax, year in zip(axes.flat, years_list.flat):
        year_data = combined_data.sel(time=str(year))

        # Order the samples by ISO week number so the line runs left to right.
        # isocalendar().week replaces the deprecated dt.week accessor.
        year_data = year_data.sortby(year_data.time.dt.isocalendar().week)
        week_of_year = year_data.time.dt.isocalendar().week

        ax.plot(week_of_year, year_data.sst,
                color='#cc4778', alpha=0.7, label='SST (ºC)')
        ax.plot(week_of_year, year_data.chlora,
                color='#0d0887', alpha=0.7, label='Chlorophyll (mg m^-3)')

        ax.set_xlim([0, 53])
        ax.set_ylim([0, 110])

        ax.set_xticks(months_ticks_list)
        ax.set_xticklabels(months_labels_list)

        ax.set_xlabel('Month')  # time of year on the x axis
        ax.set_ylabel('SST/CHLa')

        ax.set_title(station + ' ' + str(year))

        # A legend per subplot: one plt.legend() after the loop only
        # labelled the last subplot.
        ax.legend()

    plt.show()

    fig.savefig(station+ ' SST and CHL Timeseries Corr by year.png', dpi=300)

    
    
    
#~~~~~~~~~~~~~~~~~~~~~~~~~CODE~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Plot styling shared by every figure below.
sns.set()
plt.rcParams['font.family'] = 'serif'

# Load the downloaded files; the SST files carry a length-1 'level'
# dimension that is dropped before processing.
ds_chlora, ds_sst = create_ds()
ds_sst_cleaned = ds_sst.squeeze('level', drop=True)

combined_ds = process_data(ds_chlora, ds_sst_cleaned)
examine_data(combined_ds)
print(combined_ds)

#create_basicmap(combined_ds, 0)

# Station names with their latitude/longitude, matched by position.
# NOTE(review): ET5.2 is 38.580 here but 38.5807 in create_basicmap - confirm
# which value is intended.
stations = ['ET5.2', 'LE2.2', 'WT5.1', 'WE4.2', 'RET4.3']
lats = [38.580, 38.1576, 39.21309, 37.24181, 37.50869]
lons = [-76.0587, -76.598, -76.52254, -76.38634, -76.78889]

# Walk the three lists together so the loop follows the lists themselves
# instead of a hard-coded station count.
for station, lat, lon in zip(stations, lats, lons):
    print('BEGINNING ANALYSIS FOR ' + str(station))

    print('Processing for regular SST and CHL data')
    station_data = find_pixels(station, lat, lon, combined_ds)
    station_means = calc_station_mean(station, station_data)

    # Optional plots; uncomment to enable.
    #plot_timeseries_correlation(station, station_means)
    #plot_timeseries_correlation_by_year(station, station_means)

    plot_correlation(station, station_means)


 
  #  plot_timeseries_weekly(stations[i], station_anom, 'allyear')
    #plot_timeseries_weekly(stations[i], station_anom_winter, 'winter')



