ER2 Retrieval Simulation

Simulates a measurement from the ER2 aircraft and runs an extinction retrieval on the simulated data.

from ali.test.util.rt_options import simulation_rt_opts
from ali.retrieval.extinction import ExtinctionRetrieval
import numpy as np
import sasktran as sk
from ali.instrument.simulator import ImagingSimulator
from ali.atmosphere.curtain import SimulationAtmosphere
from ali.util.geometry import optical_axis_from_geometry
from ali.util.config import Config
from ali.test.util.atmospheres import simulation_atmo
from ali.test.util.sensors import simulation_sensor
import matplotlib.pyplot as plt
import os
from ali.atmosphere.gem import GEMAtmosphere
from ali.util.analysis import resolution_from_averaging_kernel, encode_multiindex, decode_to_multiindex
import xarray as xr
from ali.test.util.atmospheres import aerosol_cross_section
from ali.util.er2 import load_er2_data, er2_orientation
import matplotlib
# Select the Qt5Agg backend for all figures produced by this script.
matplotlib.use('Qt5Agg')


# Apply the matplotlib style file referenced by the project configuration.
plt.style.use(Config.MATPLOTLIB_STYLE_FILE)


def basic_extinction_retrieval(clouds=False, lat=33.0,
                               mjd=54372, two_dim=True,
                               vertical_resolution: float = None,
                               test_file: str = None):
    """
    Simulate an ER2 measurement and run an extinction retrieval on it.

    The function builds a simulation atmosphere, takes the viewing geometry
    from one sample of a recorded ER2 flight, simulates one sensor per
    wavelength, runs the extinction retrieval on the simulated radiances and
    finally plots the retrieval error and results.

    Parameters
    ----------
    clouds
        Forwarded to ``simulation_atmo`` and used as ``configure_for_cloud``
        for the radiative transfer options.
    lat
        Latitude handed to ``simulation_atmo``; also used in the default
        output file name.
    mjd
        Modified Julian date handed to ``simulation_atmo``.
    two_dim
        Forwarded to ``simulation_atmo`` and ``simulation_rt_opts``.
    vertical_resolution
        Forwarded to ``simulation_sensor``.
    test_file
        Path of the retrieval output file. When ``None`` a file in the
        ``data`` folder next to this script is used. An existing file at that
        path is deleted before the retrieval runs.
    """
    # Seed NumPy's global random generator with a fixed value.
    np.random.seed(0)

    # ------------------------------ ATMOSPHERE ----------------------------- #
    altitude_grid = np.arange(500.0, 45001.0, 200)
    atmo_sim = simulation_atmo(lat, mjd, altitude_grid, clouds=clouds, two_dim=two_dim, cloud_scaling=0.1)

    # ------------------------------ GEOMETRY ----------------------------- #
    # A single time sample of the recorded flight provides the orientation.
    flight_sample = load_er2_data('21Jul2017-2045').isel(time=20000)
    er2_geometry = er2_orientation(flight_sample, instrument_pitch_deg=-4.3)

    # ------------------------------ SENSORS ----------------------------- #
    straylight = 0.0

    def _make_sensor(wavelength):
        # One single-wavelength sensor with automatic exposure capped at 30.
        channel = simulation_sensor(wavelength_nm=[wavelength], simulated_pixel_averaging=50,
                                    image_vert_fov=5, image_horiz_fov=5.0,
                                    noisy=True, vertical_resolution=vertical_resolution,
                                    straylight=straylight, dual_channel=False, er2=True)
        channel.auto_exposure = True
        channel.max_exposure = 30
        return channel

    sensor_wavel = np.array([750.0, 875.0, 925.0, 1025., 1250.0, 1440.0, 1550.0])
    sensors = [_make_sensor(wavelength) for wavelength in sensor_wavel]
    # Every sensor shares the same ER2 geometry object.
    opt_geom = [er2_geometry for _ in sensors]

    # ------------------------------ SIMULATOR ----------------------------- #
    sim_opts = simulation_rt_opts(configure_for_cloud=clouds, cloud_lower_bound=5000.0,
                                  cloud_upper_bound=18000.0, two_dim=two_dim)
    simulator = ImagingSimulator(sensors=sensors, optical_axis=opt_geom,
                                 atmosphere=atmo_sim, options=sim_opts)
    simulator_settings = (
        ('grid_sampling', True),
        ('group_scans', False),
        ('num_vertical_samples', 512),
        ('dual_polarization', True),
        ('save_diagnostics', True),
    )
    for attribute, value in simulator_settings:
        setattr(simulator, attribute, value)

    measurement_l1 = simulator.calculate_radiance()
    # Tangent altitudes of the first measurement divided by 1000; the value is
    # not used further in this function.
    tanalts = measurement_l1[0].tangent_locations().altitude / 1000

    # ------------------------------ RETRIEVAL ----------------------------- #
    retrieval = ExtinctionRetrieval(sensors, opt_geom, measurement_l1)
    retrieval.aerosol_vector_wavelength = [750.0]
    retrieval.max_iterations = 5
    retrieval.vertical_samples = 512

    if test_file is None:
        data_folder = os.path.join(os.path.dirname(__file__), 'data')
        test_file = os.path.join(data_folder, f'test-er2-retrieval-{lat:.2f}-{straylight:.4f}.nc')
    # Start from a clean output file.
    if os.path.isfile(test_file):
        os.remove(test_file)

    retrieval.output_filename = test_file
    retrieval.simulation_atmosphere = atmo_sim
    # Use the albedo of the simulated atmosphere's BRDF in the retrieval.
    retrieval.brdf = atmo_sim.brdf.albedo
    retrieval.couple_normalization_altitudes = False
    retrieval.use_cloud_for_lower_bound = False
    retrieval.normalization_altitudes = [20000.0, 22000.0]
    retrieval.retrieve()

    # ------------------------------ PLOTTING ----------------------------- #
    retrieval.plot_error(test_file)
    plot_results(test_file)


def plot_level1(measurement_l1):
    """
    Plot exposure times and radiance profiles of level 1 measurements.

    Two figures are created. The first shows exposure time against
    wavelength, with even-indexed measurements labelled 'Vertical' and
    odd-indexed ones labelled 'Horizontal'. The second shows the radiance of
    every even-indexed measurement against the tangent altitudes of the first
    measurement, coloured by wavelength.

    Parameters
    ----------
    measurement_l1
        Sequence of measurements; each must provide ``data.exposure_time``,
        ``data.wavelength``, ``data.radiance`` and ``tangent_locations()``.
    """

    def _styled_legend(axis, title):
        # Opaque, borderless legend with a small bold title.
        legend = axis.legend(framealpha=1, facecolor=axis.get_facecolor(), edgecolor='none', fontsize='small')
        legend.set_title(title, prop={'size': 'small', 'weight': 'bold'})

    exposure_times = np.array([float(meas.data.exposure_time.values) for meas in measurement_l1])
    wavelengths = np.array([float(meas.data.wavelength.values) for meas in measurement_l1])

    # ---- Figure 1: exposure time per wavelength, split by polarization ---- #
    fig, ax = plt.subplots(1, 1, figsize=(5, 3), dpi=200)
    fig.subplots_adjust(bottom=0.12, right=0.97, top=0.97)
    for offset, polarization in enumerate(('Vertical', 'Horizontal')):
        ax.plot(wavelengths[offset::2], exposure_times[offset::2], label=polarization)
    ax.set_ylabel('Exposure Time [s]')
    ax.set_xlabel('Wavelength [nm]')
    ax.set_ylim(0, 120)
    ax.set_xlim(700, 1400)
    _styled_legend(ax, 'Polarization')

    # ---- Figure 2: radiance profiles of the even-indexed measurements ---- #
    tanalts = measurement_l1[0].tangent_locations().altitude / 1000
    fig, ax = plt.subplots(1, 1, figsize=(5, 3), dpi=200)
    fig.subplots_adjust(bottom=0.12, right=0.97, top=0.97)
    colormap = plt.cm.Spectral
    shortest = wavelengths.min()
    longest = wavelengths.max()
    for meas in measurement_l1[0::2]:
        wavelength = float(meas.data.wavelength.values)
        # Map the wavelength onto the colormap using the observed wavelength range.
        fraction = (wavelength - shortest) / (longest - shortest)
        ax.plot(meas.data.radiance, tanalts, color=colormap(fraction), label=f'{wavelength:0.0f} nm')
    _styled_legend(ax, 'Wavelength')
    ax.set_ylabel('Tangent Altitude [km]')
    ax.set_xlabel('Radiance [photons/nm/s/st/cm$^2$]')
    ax.set_xlim(0, 5e13)
    ax.set_ylim(0, 22.5)


def plot_sensor(sensor):
    """
    Plot the transmission of each optical component of a sensor.

    For wavelengths from 600 to 1500 in steps of 10, the ``[0, 0]`` element
    of each component's ``matrix(wavelength)`` in ``sensor._optics`` is
    plotted, together with the same element of ``sensor.optics`` drawn as a
    thin black line labelled 'total'.

    Parameters
    ----------
    sensor
        Object exposing a ``_optics`` mapping of components and an ``optics``
        attribute, each providing ``matrix(wavelength)``.
    """
    wavelengths = np.arange(600, 1500.1, 10.0)

    # One curve per component plus one for the combined optics.
    curves = {name: np.ones_like(wavelengths) for name in sensor._optics.keys()}
    curves['total'] = np.ones_like(wavelengths)

    for position, wavelength in enumerate(wavelengths):
        for name in sensor._optics.keys():
            curves[name][position] = sensor._optics[name].matrix(wavelength)[0, 0]
        curves['total'][position] = sensor.optics.matrix(wavelength)[0, 0]

    fig, ax = plt.subplots(1, 1, figsize=(5, 3), dpi=200)
    fig.subplots_adjust(bottom=0.12, right=0.97, top=0.97)
    for name, values in curves.items():
        # The combined curve is drawn as a thin black line.
        line_style = {'color': 'k', 'lw': 1} if name == 'total' else {}
        ax.plot(wavelengths, values, label=name, **line_style)

    legend = ax.legend(framealpha=1, facecolor=ax.get_facecolor(), edgecolor='none', fontsize='small')
    legend.set_title('Component', prop={'size': 'small', 'weight': 'bold'})
    ax.set_ylabel('Transmission')
    ax.set_xlabel('Wavelength [nm]')
    ax.set_xlim(600, 1500)
    ax.set_ylim(0, 1)


def plot_results(test_file: str = None):
    """
    Plot the retrieval results stored in ``test_file`` and save the figure.

    The figure is produced by ``ExtinctionRetrieval.plot_results``, its axis
    limits are adjusted, and it is written as a PNG and a PDF into
    ``Config.DIAGNOSTIC_FIGURE_FOLDER``.

    Parameters
    ----------
    test_file
        Path of the retrieval output file. When ``None``, the file
        'test-er2-retrieval-33.00-0.0000.nc' in the ``data`` folder next to
        this script is used.
    """
    if test_file is None:
        data_folder = os.path.join(os.path.dirname(__file__), 'data')
        test_file = os.path.join(data_folder, 'test-er2-retrieval-33.00-0.0000.nc')

    kernel_options = {
        'ret_alts': np.arange(7, 35, 3),
        'fwhm_axis': False,
        'alpha_scale': 8,
    }
    fig, ax = ExtinctionRetrieval.plot_results(test_file, aerosol_scale=1000,
                                               kernel_kwargs=kernel_options)

    # Horizontal limits per panel, applied from the last panel to the first.
    for panel, (lower, upper) in ((2, (-0.4, 1.5)), (1, (0, 0.85)), (0, (0, 1.9))):
        ax[panel].set_xlim(lower, upper)
    ax[0].set_ylim(5, 35)

    figure_folder = Config.DIAGNOSTIC_FIGURE_FOLDER
    fig.savefig(os.path.join(figure_folder, 'extinction-retrieval-scaled-cloud.png'), dpi=600)
    fig.savefig(os.path.join(figure_folder, 'extinction-retrieval-test.pdf'))


if __name__ == '__main__':

    # Run the simulation and retrieval with clouds enabled and two_dim
    # disabled, at a latitude of 33 degrees.
    # To plot an existing output file without re-running the retrieval, call
    # plot_results() instead.
    run_settings = {'clouds': True, 'two_dim': False, 'lat': 33.0}
    basic_extinction_retrieval(**run_settings)
../../_images/er2_retrieval_example.png