Source code for ali.retrieval.measvec
import numpy as np
import xarray as xr
from typing import Union, List, Dict
from skretrieval.core.radianceformat import RadianceBase
[docs]
class Transformer:
"""
These classes can be used to compose functions applied to a radiance for use in a `MeasurementVectorElement`.
"""
def transform(self, l1_data: RadianceBase, wf=None, covariance=False):
if 'error' in l1_data.data.keys():
if covariance:
l1_data.data['covariance'] = self.covariance(l1_data)
return l1_data
def jacobian(self, l1_data):
return np.diag(np.ones(len(l1_data.data.los)))
def covariance(self, l1_data):
if 'covariance' in l1_data.data.keys():
l1_covariance = l1_data.data['covariance']
else:
l1_covariance = np.diag(l1_data.data['error'] ** 2)
J = self.jacobian(l1_data)
return xr.DataArray(J @ l1_covariance @ J.T, dims=['los', 'los2'],
coords=[l1_data.data.los.values, l1_data.data.los.values])
[docs]
class MeasurementVectorElement:
"""
Applies a set of transforms to a level 1 measurement.
"""
def __init__(self, name=None):
self._transforms = {}
self._name = name
@property
def transforms(self):
return self._transforms
@property
def name(self):
return self._name
[docs]
def add_transform(self, transform: Transformer, index: Union[int, str] = 'post'):
"""
Add a transform to the measurment vector element.
Parameters
----------
transform : Transformer
transformation to apply to the measurements.
index :
Order in which the transform will be applied. If `post` (default) the transform will be appended to the
current list. If `pre` transform be prepended. If an integer the transform will be placed at position
`index` in the current list of transforms.
"""
if type(index) == str:
if self._transforms:
if index.lower() == 'pre':
index = sorted(self._transforms.keys())[0] - 1
elif index.lower() == 'post':
index = sorted(self._transforms.keys())[-1] + 1
else:
index = 0
if type(index) == int:
if index in self._transforms.keys():
for key in sorted(self._transforms.keys(), reverse=True):
self._transforms[index + 1] = self._transforms[index]
if key <= index:
break
self._transforms[index] = transform
else:
raise ValueError(f'index should be a `pre`, `post` or an integer, got {index} with type {type(index)}')
def remove_transform(self, index):
del self._transforms[index]
[docs]
def transform(self, l1_data, covariance=False):
"""
Transform the level 1 data into measurerment vector space.
"""
for t in sorted(self._transforms.keys()):
l1_data = self._transforms[t].transform(l1_data, covariance=covariance)
return l1_data
[docs]
def meas_dict(self, l1_data: Union[List[RadianceBase], RadianceBase], covariance: bool = False) -> Dict[str, np.ndarray]:
"""
Transform the level 1 data into measurement vector space and return the measurement vector, error and weighting
functions.
Parameters
----------
l1_data:
level 1 data that the transforms will be applied to.
covariance:
Whether to compute the full covariance matrix or only the diagonal. Default False
"""
data = self.transform(l1_data, covariance=covariance).data
result = dict()
result['y'] = data.radiance.values
if result['y'].shape == ():
result['y'] = np.array([result['y']])
result['y'][~np.isfinite(result['y'])] = np.nan
if 'error' in data.keys():
result['y_error'] = data['error'].values ** 2
wf_keys = [key for key in data.keys() if 'wf' in key]
if wf_keys:
for key in wf_keys:
result[key] = data[key].transpose('los', 'perturbation', ...)
return result
[docs]
class MeasurementVector:
"""
A full measurement vector made up of a collection of measurement vector elements.
"""
def __init__(self, meas_vec_elements: List[MeasurementVectorElement], covariance=False, drop_zero_error=True):
self._elements = meas_vec_elements
self._covariance = covariance
self._drop_zero_error = drop_zero_error
@property
def elements(self) -> List[MeasurementVectorElement]:
return self._elements
[docs]
def meas_dict(self, l1_data) -> Dict[str, np.ndarray]:
"""
Transform the level 1 data into the measurement vector space and return the measurement vector, error and weighting
functions.
Parameters
----------
l1_data:
level 1 data that the transforms will be applied to.
"""
md = []
for el in self._elements:
md.append(el.meas_dict(l1_data, covariance=self._covariance))
meas_dict = {}
for key in md[0]:
if len(md[0][key].shape) == 1:
meas_dict[key] = np.hstack([m[key] for m in md])
else:
meas_dict[key] = np.vstack([m[key] for m in md])
if self._drop_zero_error:
if 'y_error' in meas_dict.keys():
if np.nansum(meas_dict['y_error']) == 0.0:
del meas_dict['y_error']
return meas_dict