"""Grid utilities: cropping, interpolation, and event-centred patch extraction.
Supports arbitrary input resolution and domain — crops to NH and
bilinearly interpolates to a regular 1.5° grid matching ERA5-style layout.
"""
from __future__ import annotations
import numpy as np
from scipy.interpolate import RegularGridInterpolator
from dataclasses import dataclass, field
from typing import Optional
from .constants import (
TARGET_LAT, TARGET_LON, LAT_HALF, LON_HALF, R_EARTH,
)
[docs]
@dataclass(frozen=True)
class NHGrid:
    """Northern Hemisphere regular lat-lon grid.

    Attributes:
        lat: 1-D latitude array [deg]. The standard grid is descending
            (90 -> 0); ascending arrays are detected via `lat_descending`.
        lon: 1-D longitude array [deg] (e.g. -180 -> 180).

    Derived (read-only) properties: dlat, dlon, nlat, nlon, lat_descending,
    dy, dx_arr.

    Equality and hashing are value-based. ``==`` compares the coordinate
    arrays with ``np.array_equal``, and ``hash`` is derived from the same
    values. The dataclass-generated versions cannot handle ndarray fields:
    ``==`` raises ValueError and ``hash`` raises TypeError.
    """

    lat: np.ndarray
    lon: np.ndarray

    def __eq__(self, other: object) -> bool:
        """Return True when both coordinate arrays match in shape and value."""
        if not isinstance(other, NHGrid):
            return NotImplemented
        return bool(np.array_equal(self.lat, other.lat)
                    and np.array_equal(self.lon, other.lon))

    def __hash__(self) -> int:
        """Hash consistent with `__eq__` (equal grids hash equally)."""
        def _key(arr: np.ndarray) -> tuple:
            # Converting to float64 makes equal values of different dtypes
            # (e.g. int 3 and float 3.0) produce identical bytes. Adding 0.0
            # folds -0.0 into 0.0, which array_equal treats as equal.
            as_f64 = np.asarray(arr, dtype=np.float64) + 0.0
            return (as_f64.shape, as_f64.tobytes())
        return hash((_key(self.lat), _key(self.lon)))

    @property
    def dlat(self) -> float:
        """Grid spacing in latitude [deg] (mean absolute step, NaN-tolerant)."""
        return float(abs(np.nanmean(np.diff(self.lat))))

    @property
    def dlon(self) -> float:
        """Grid spacing in longitude [deg] (mean absolute step, NaN-tolerant)."""
        return float(abs(np.nanmean(np.diff(self.lon))))

    @property
    def nlat(self) -> int:
        """Number of latitude points."""
        return len(self.lat)

    @property
    def nlon(self) -> int:
        """Number of longitude points."""
        return len(self.lon)

    @property
    def lat_descending(self) -> bool:
        """True if latitude array is strictly descending."""
        return bool(np.all(np.diff(self.lat) < 0))

    @property
    def dy(self) -> float:
        """Meridional grid spacing in metres (dlat in radians times R_EARTH)."""
        return np.deg2rad(self.dlat) * R_EARTH

    @property
    def dx_arr(self) -> np.ndarray:
        """Zonal grid spacing per latitude row [m], shape (nlat,)."""
        dx = np.deg2rad(self.dlon) * R_EARTH * np.cos(np.deg2rad(self.lat))
        # cos(lat) -> 0 at the pole; floor the spacing at 1% of dy,
        # presumably so downstream finite differences never divide by ~0.
        return np.maximum(dx, self.dy * 0.01)
[docs]
def default_nh_grid() -> NHGrid:
    """Build the standard 1.5° Northern Hemisphere grid.

    The coordinates come from TARGET_LAT / TARGET_LON. Each call copies them,
    so the returned grid never shares memory with the module-level constants.
    """
    lat_values = TARGET_LAT.copy()
    lon_values = TARGET_LON.copy()
    return NHGrid(lat=lat_values, lon=lon_values)
def crop_to_nh(lat: np.ndarray, lon: np.ndarray,
               data: np.ndarray, lat_axis: int = -2
               ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Keep only the Northern Hemisphere (lat >= 0) part of ``data``.

    Parameters:
        lat: Input latitude array, aligned with ``data`` along ``lat_axis``.
        lon: Input longitude array; passed through unchanged.
        data: N-D array with latitude along ``lat_axis``.
        lat_axis: Axis index for latitude (negative counts from the end).
    Returns:
        (nh_lat, lon, nh_data). The original latitude order is preserved, and
        NaN latitudes are dropped because ``NaN >= 0`` is False.
    """
    northern = lat >= 0
    cropped_lat = lat[northern]
    # Full slices on every axis except the latitude axis, which gets the mask.
    indexer = [slice(None) for _ in range(data.ndim)]
    indexer[lat_axis] = northern
    cropped_data = data[tuple(indexer)]
    return cropped_lat, lon, cropped_data
def _lon_into_source_range(dst_lon: np.ndarray,
                           src_lon: np.ndarray) -> np.ndarray:
    """Shift target longitudes outside [src_lon[0], src_lon[-1]] by 360° multiples.

    Points already inside the source range are returned bit-for-bit
    unchanged. Points outside it are mapped into [src_lon[0], src_lon[0]+360).
    """
    lo, hi = src_lon[0], src_lon[-1]
    outside = (dst_lon < lo) | (dst_lon > hi)
    return np.where(outside, (dst_lon - lo) % 360.0 + lo, dst_lon)


def _close_periodic_seam(src_lon: np.ndarray, data: np.ndarray
                         ) -> tuple[np.ndarray, np.ndarray]:
    """Append a wrapped copy of the first longitude column for global grids.

    The column is appended only when the gap between the last longitude and
    the first longitude + 360 is positive and at most one mean grid step.
    A gap of 0 means the seam is already duplicated. A larger gap means a
    regional source that must not be bridged.
    """
    if src_lon.size < 2:
        return src_lon, data
    step = (src_lon[-1] - src_lon[0]) / (src_lon.size - 1)
    gap = src_lon[0] + 360.0 - src_lon[-1]
    if 0.0 < gap <= step * (1.0 + 1e-6):
        src_lon = np.append(src_lon, src_lon[0] + 360.0)
        data = np.concatenate([data, data[..., :1]], axis=-1)
    return src_lon, data


def bilinear_interpolate(
    src_lat: np.ndarray,
    src_lon: np.ndarray,
    data: np.ndarray,
    dst_lat: Optional[np.ndarray] = None,
    dst_lon: Optional[np.ndarray] = None,
) -> np.ndarray:
    """Bilinearly interpolate 2-D or N-D data to a target lat-lon grid.

    The last two axes of ``data`` are assumed to be (lat, lon). All leading
    axes (e.g. time, level) are interpolated together in one interpolator
    call, so the target points are located only once.

    Longitude is treated as periodic:

    * Target longitudes outside the source range are shifted by multiples of
      360° into it, so a 0..360 source can serve a -180..180 target.
    * A source grid covering the full circle is closed across its seam.
      Targets between the last and first source longitude are then
      interpolated rather than NaN.

    Parameters:
        src_lat: Source latitude (ascending or descending).
        src_lon: Source longitude (ascending or descending).
        data: Array with shape (..., nlat_src, nlon_src).
        dst_lat: Target latitude array (default: TARGET_LAT).
        dst_lon: Target longitude array (default: TARGET_LON).
    Returns:
        Interpolated array with shape (..., nlat_dst, nlon_dst), NaN where
        a target point lies outside the source domain. Floating/complex input
        keeps its dtype; integer/bool input is returned as float64, since it
        cannot hold fractional values or NaN.
    """
    # None sentinels avoid binding module arrays as default arguments.
    if dst_lat is None:
        dst_lat = TARGET_LAT
    if dst_lon is None:
        dst_lon = TARGET_LON
    src_lat = np.asarray(src_lat, dtype=float)
    src_lon = np.asarray(src_lon, dtype=float)
    dst_lat = np.asarray(dst_lat, dtype=float)
    dst_lon = np.asarray(dst_lon, dtype=float)
    data = np.asarray(data)

    if np.issubdtype(data.dtype, np.inexact):
        out_dtype = data.dtype
    else:
        out_dtype = np.dtype(np.float64)

    # Normalise both source axes to ascending order for the interpolator.
    if src_lat[0] > src_lat[-1]:
        src_lat = src_lat[::-1]
        data = data[..., ::-1, :]
    if src_lon[0] > src_lon[-1]:
        src_lon = src_lon[::-1]
        data = data[..., ::-1]

    # Map targets using the original source bounds, before any seam padding.
    query_lon = _lon_into_source_range(dst_lon, src_lon)
    src_lon, data = _close_periodic_seam(src_lon, data)

    ny, nx = dst_lat.size, dst_lon.size
    batch_shape = data.shape[:-2]
    lat_g, lon_g = np.meshgrid(dst_lat, query_lon, indexing="ij")
    points = np.column_stack([lat_g.ravel(), lon_g.ravel()])

    stacked = data.reshape(-1, data.shape[-2], data.shape[-1])
    if stacked.shape[0] == 0:
        return np.empty(batch_shape + (ny, nx), dtype=out_dtype)

    # Put the batch axis last. RegularGridInterpolator treats trailing value
    # axes as independent fields sharing one point search.
    values = np.moveaxis(stacked, 0, -1)
    interp = RegularGridInterpolator(
        (src_lat, src_lon), values,
        method="linear", bounds_error=False, fill_value=np.nan,
    )
    result = np.moveaxis(interp(points), -1, 0)  # (nbatch, ny * nx)
    return result.reshape(batch_shape + (ny, nx)).astype(out_dtype, copy=False)
[docs]
@dataclass
class EventPatch:
    """Event-centred patch extraction from a full NH grid.

    Patches produced by `extract` always have latitude ascending along the
    second-to-last axis (southern edge in row 0), whatever the ordering of
    ``grid.lat``. Longitude is wrapped periodically along the last axis.

    Attributes:
        grid: The underlying NHGrid.
        lat_half: Half-window in latitude [deg].
        lon_half: Half-window in longitude [deg].
    """

    grid: NHGrid
    lat_half: float = LAT_HALF
    lon_half: float = LON_HALF

    @property
    def lat_pad(self) -> int:
        """Latitude padding in grid points (lat_half / dlat, rounded)."""
        return int(round(self.lat_half / self.grid.dlat))

    @property
    def lon_pad(self) -> int:
        """Longitude padding in grid points (lon_half / dlon, rounded)."""
        return int(round(self.lon_half / self.grid.dlon))

    @property
    def patch_shape(self) -> tuple[int, int]:
        """Shape of the extracted patch (nlat_patch, nlon_patch)."""
        return (2 * self.lat_pad + 1, 2 * self.lon_pad + 1)

    def relative_grid(self) -> tuple[np.ndarray, np.ndarray]:
        """Return relative coordinate arrays (Y_rel, X_rel) in degrees.

        Offsets are whole multiples of the grid spacing, so they line up with
        the rows/columns returned by `extract`. When lat_half / lon_half are
        exact multiples of dlat / dlon, the arrays span exactly
        [-lat_half, lat_half] x [-lon_half, lon_half].
        """
        lat_pad, lon_pad = self.lat_pad, self.lon_pad
        # Use the extent actually covered by the padded patch. This can differ
        # from lat_half / lon_half after the rounding in lat_pad / lon_pad.
        lat_extent = lat_pad * self.grid.dlat
        lon_extent = lon_pad * self.grid.dlon
        rlat = np.linspace(-lat_extent, lat_extent, 2 * lat_pad + 1)
        rlon = np.linspace(-lon_extent, lon_extent, 2 * lon_pad + 1)
        Y_rel, X_rel = np.meshgrid(rlat, rlon, indexing="ij")
        return Y_rel, X_rel

    def nearest_idx(self, lat0: float, lon0: float
                    ) -> tuple[int, int, bool]:
        """Find the nearest grid index and check whether the patch fits.

        Longitude distance is measured on the circle. On a -180..180 grid,
        lon0=179.9 matches -180, and lon0=270 (0..360 convention) matches -90.

        Parameters:
            lat0: Event centre latitude [deg].
            lon0: Event centre longitude [deg], in any 360° convention.
        Returns:
            (ilat, ilon, ok), where ok=True means the full patch fits within
            the latitude bounds.
        """
        ilat = int(np.abs(self.grid.lat - lat0).argmin())
        lon_dist = np.abs((self.grid.lon - lon0 + 180.0) % 360.0 - 180.0)
        ilon = int(lon_dist.argmin())
        ok = (ilat >= self.lat_pad and
              ilat + self.lat_pad < self.grid.nlat)
        return ilat, ilon, ok

    def wrapped_lon_index(self, ilon: int) -> np.ndarray:
        """Return longitude indices with periodic wrapping.

        Parameters:
            ilon: Centre longitude index.
        Returns:
            Array of longitude indices of length (2 * lon_pad + 1).
        """
        start = ilon - self.lon_pad
        return (np.arange(2 * self.lon_pad + 1) + start) % self.grid.nlon

    def extract(self, data: np.ndarray, ilat: int, ilon: int,
                eff_north: Optional[int] = None,
                eff_south: Optional[int] = None) -> np.ndarray:
        """Extract an event-centred patch from (..., nlat, nlon) data.

        Handles zonal wrap and asymmetric polar padding. At most ``eff_north``
        rows north and ``eff_south`` rows south of the centre are copied.
        Rows beyond those limits, or beyond the grid edge, stay NaN. The
        centre row always lands at index ``lat_pad`` of the output, even
        when the window is clipped at either edge of the grid.

        Parameters:
            data: Array with last two dims (nlat, nlon) matching ``grid``.
            ilat: Centre latitude index.
            ilon: Centre longitude index.
            eff_north: Effective northward padding (default: lat_pad),
                clamped to [0, lat_pad].
            eff_south: Effective southward padding (default: lat_pad),
                clamped to [0, lat_pad].
        Returns:
            Patch array of shape (..., 2*lat_pad+1, 2*lon_pad+1), latitude
            ascending (south -> north), NaN-filled where data doesn't reach.
            Floating/complex input keeps its dtype. Other dtypes cannot hold
            NaN and are returned as float64.
        """
        pad = self.lat_pad
        # Clamp so the copied block can never start before row 0 of the
        # output or overrun it.
        eff_north = pad if eff_north is None else min(max(eff_north, 0), pad)
        eff_south = pad if eff_south is None else min(max(eff_south, 0), pad)

        lon_idx = self.wrapped_lon_index(ilon)
        if np.issubdtype(data.dtype, np.inexact):
            out_dtype = data.dtype
        else:
            out_dtype = np.dtype(np.float64)
        out_shape = data.shape[:-2] + (2 * pad + 1, len(lon_idx))
        out = np.full(out_shape, np.nan, dtype=out_dtype)

        nlat = self.grid.nlat
        descending = self.grid.lat_descending
        if descending:
            # Row index grows southward: north rows come before the centre.
            i0 = max(0, ilat - eff_north)
            i1 = min(nlat, ilat + eff_south + 1)
            n_south = i1 - 1 - ilat
        else:
            # Row index grows northward: south rows come before the centre.
            i0 = max(0, ilat - eff_south)
            i1 = min(nlat, ilat + eff_north + 1)
            n_south = ilat - i0

        block = data[..., i0:i1, :][..., lon_idx]
        if descending:
            block = block[..., ::-1, :]  # reorder rows to south -> north
        # n_south counts the rows actually copied south of the centre, which
        # is fewer than eff_south when the grid edge clips the window. Using
        # it keeps the centre row at index `pad`.
        y0 = pad - n_south
        out[..., y0:y0 + block.shape[-2], :] = block
        return out