"""Composite accumulation from per-event NPZ files.

Corresponds to **Pass 2** of the core script
``ss01_rwb_stage_multilevel_composites.py``.

Usage (via CLI)::

    pvtend-pipeline composite \\
        --npz-dir /path/to/composite_blocking_tempest \\
        --rwb-pkl /path/to/outputs/rwb_variant_tracksets.pkl \\
        --output /path/to/outputs/composite.pkl

Or programmatically::

    from pathlib import Path

    from pvtend.composite_builder import build_composites, CompositeConfig
    from pvtend.classify import ClassifyResult

    rwb = ClassifyResult.load("rwb_variant_tracksets.pkl")
    cfg = CompositeConfig(npz_dir=Path("..."), stages=["onset", "peak", "decay"])
    comp = build_composites(cfg, rwb)
    comp.save("composite.pkl")
"""
from __future__ import annotations
import pickle
import re
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Sequence
import numpy as np
from .classify import ClassifyResult, _parse_track_id, _parse_dh, _load_excluded
# ── Picklable defaultdict factories (lambdas are NOT picklable) ──────
def _dd_dict() -> defaultdict:
return defaultdict(dict)
def _dd_int() -> defaultdict:
return defaultdict(int)
# ── Metadata keys in NPZ — skip when accumulating ────────────────────
_META = frozenset({
"Y_rel", "X_rel", "levels", "wavg_levels", "H_SCALE", "G0",
"lat_vec", "lon_vec_unwrapped",
"track_id", "lat0", "lon0", "center_lat", "center_lon",
"center_mode", "ts", "dh",
})
def _levels_indexer(
levels_file: np.ndarray, levels_ref: np.ndarray,
) -> np.ndarray | None:
"""Map file levels → reference levels (index array or None)."""
if (levels_file.shape == levels_ref.shape
and np.all(levels_file == levels_ref)):
return np.arange(levels_ref.size, dtype=int)
pos = {int(lv): i for i, lv in enumerate(levels_file.tolist())}
try:
return np.array(
[pos[int(lv)] for lv in levels_ref.tolist()], dtype=int
)
except (KeyError, ValueError):
return None
def _accumulate(
sums: dict[str, np.ndarray],
valids: dict[str, np.ndarray],
key: str,
arr: np.ndarray,
) -> None:
"""NaN-safe in-place accumulation."""
mask = np.isfinite(arr)
a0 = np.where(mask, arr, 0.0)
if key not in sums:
sums[key] = a0.astype(np.float64, copy=True)
valids[key] = mask.astype(np.uint16, copy=True)
else:
sums[key] += a0
valids[key] += mask
# ── Config ────────────────────────────────────────────────────────────
[docs]
@dataclass
class CompositeConfig:
    """Configuration for Pass-2 composite accumulation.

    Attributes:
        npz_dir: Root directory with ``{stage}/dh=±N/*.npz``.
        stages: Event stages to process (one sub-directory of *npz_dir*
            each).
        exclude_file: Optional exclude-track CSV.

    For convenience (e.g. raw CLI arguments), a plain-string ``npz_dir``
    or non-empty ``exclude_file`` is converted to :class:`~pathlib.Path`.
    A single stage given as a bare string is wrapped in a list, instead of
    being iterated character by character.
    """

    npz_dir: Path = Path(".")
    stages: list[str] = field(
        default_factory=lambda: ["onset", "peak", "decay"]
    )
    exclude_file: Path | None = None

    def __post_init__(self) -> None:
        # ``cfg.npz_dir / stage`` only works on a Path, not on a str.
        if isinstance(self.npz_dir, str):
            self.npz_dir = Path(self.npz_dir)
        # Empty string is left untouched: Path("") would become ".".
        if isinstance(self.exclude_file, str) and self.exclude_file:
            self.exclude_file = Path(self.exclude_file)
        if isinstance(self.stages, str):
            self.stages = [self.stages]
# ── Result container ──────────────────────────────────────────────────
[docs]
@dataclass
class CompositeResult:
    """Accumulated composite data, supporting *original* + RWB variants.

    Variants exposed:
        ``original`` — all events (no RWB filter);
        ``AWB_{stage}``, ``CWB_{stage}``, ``NEUTRAL_{stage}``
        for each stage.

    Access composites via :meth:`mean_3d` and :meth:`reduce_2d`.
    The accumulators hold running sums and per-point valid-sample counts.
    Means are formed lazily as ``sum / count`` wherever ``count > 0``.
    """

    # Level coordinate of axis 0 of every 3-D field (pressure levels,
    # presumably hPa — see reduce_2d).
    levels: np.ndarray
    # Relative-grid coordinates read from the NPZ inputs (X_rel / Y_rel).
    x_rel: np.ndarray
    y_rel: np.ndarray
    # Scale height for the exp(−z/H) weights; None → constants.H_SCALE.
    h_scale: float | None
    # Stages requested at build time, and the sorted 3-D field names found.
    stages: list[str]
    fields_3d: list[str]
    # ``original`` accumulators — {evt: {dh: {field: arr}}}
    sums: dict[str, dict[int, dict[str, np.ndarray]]]
    valids: dict[str, dict[int, dict[str, np.ndarray]]]
    # Number of files accumulated — {evt: {dh: n}}
    counts: dict[str, dict[int, int]]
    # RWB-variant accumulators — {variant: {evt: {dh: {field: arr}}}}
    sums_v: dict[str, dict[str, dict[int, dict[str, np.ndarray]]]]
    valids_v: dict[str, dict[str, dict[int, dict[str, np.ndarray]]]]
    counts_v: dict[str, dict[str, dict[int, int]]]
    # "original" followed by the RWB variant names.
    variant_names: list[str]

    # ── access helpers ──
    @staticmethod
    def _is_original(variant: str | None) -> bool:
        """True when *variant* selects the unfiltered ``original`` data."""
        return not variant or str(variant).lower() == "original"

    def _pick(
        self, variant: str | None, stage: str, dh: int,
    ) -> tuple[dict, dict, int]:
        """Return ``(sums, valids, count)`` for one variant / stage / dh.

        Missing entries yield empty dicts and a count of 0.
        """
        if self._is_original(variant):
            s = self.sums.get(stage, {}).get(dh, {})
            v = self.valids.get(stage, {}).get(dh, {})
            c = self.counts.get(stage, {}).get(dh, 0)
        else:
            s = self.sums_v.get(variant, {}).get(stage, {}).get(dh, {})
            v = self.valids_v.get(variant, {}).get(stage, {}).get(dh, {})
            c = self.counts_v.get(variant, {}).get(stage, {}).get(dh, 0)
        return s, v, c

    def mean_3d(
        self,
        field: str,
        stage: str,
        dh: int,
        *,
        variant: str | None = "original",
    ) -> np.ndarray | None:
        """Return the NaN-safe mean 3-D composite array.

        Points with no valid samples are NaN.  Returns ``None`` if *field*
        was never accumulated for this variant / stage / dh.
        """
        s, v, _ = self._pick(variant, stage, dh)
        arr_sum = s.get(field)
        vcount = v.get(field)
        if arr_sum is None or vcount is None:
            return None
        arr = np.asarray(arr_sum, dtype=np.float64)
        vc = np.asarray(vcount, dtype=np.float64)
        out = np.full_like(arr, np.nan)
        np.divide(arr, vc, out=out, where=vc > 0)
        return out

    def reduce_2d(
        self,
        field: str,
        stage: str,
        dh: int,
        *,
        variant: str | None = "original",
        level_mode: str | int | None = None,
    ) -> np.ndarray | None:
        """Reduce a 3-D composite to 2-D.

        ``level_mode=None|"all"|"3d"`` (strings case-insensitive) → return
        the full 3-D array.
        ``level_mode="wavg"`` (also ``"w-avg"``, ``"weighted"``) →
        ``exp(−z/H)`` weighted average over 300, 250, 200 hPa (the
        ``constants.WAVG_LEVELS``; the nearest stored level is used).
        ``level_mode=300`` (any number or numeric string) → nearest
        pressure level slice.

        Returns ``None`` when the composite has no data for *field*.

        Raises:
            ValueError: unsupported *level_mode*, or ``"wavg"`` requested
                without a geopotential composite (``z_3d`` / ``z``).
        """
        arr3d = self.mean_3d(field, stage, dh, variant=variant)
        if arr3d is None:
            return None
        if level_mode is None:
            return arr3d
        if isinstance(level_mode, str):
            # Compare strings only as strings, so that numeric modes are
            # never ==-compared against string literals.
            mode = level_mode.lower()
            if mode in ("", "all", "3d"):
                return arr3d
            if mode in ("wavg", "w-avg", "weighted"):
                return self._wavg(arr3d, stage, dh, variant)
        try:
            lev_val = float(level_mode)
        except (TypeError, ValueError) as exc:
            raise ValueError(f"Unsupported level_mode {level_mode!r}") from exc
        levels = np.asarray(self.levels, dtype=float)
        idx = int(np.nanargmin(np.abs(levels - lev_val)))
        return arr3d[idx]

    def _wavg(
        self,
        arr3d: np.ndarray,
        stage: str,
        dh: int,
        variant: str | None,
    ) -> np.ndarray:
        """exp(−z/H)-weighted average of *arr3d* over the WAVG levels."""
        # exp(−z/H) weighted average over 300, 250, 200 hPa
        # (matches tendency.py vwm — canonical pvtend recipe)
        from .constants import WAVG_LEVELS as _WL, H_SCALE as _HS, G0 as _G0
        wavg_hpa = np.asarray(_WL, dtype=float)
        levels_arr = np.asarray(self.levels, dtype=float)
        indices = [int(np.nanargmin(np.abs(levels_arr - lv)))
                   for lv in wavg_hpa]
        z_name = "z_3d" if "z_3d" in self.fields_3d else "z"
        z3d = self.mean_3d(z_name, stage, dh, variant=variant)
        if z3d is None:
            raise ValueError("Need z_3d for wavg")
        z_m = z3d[indices] / _G0  # geopotential → metres
        h = float(self.h_scale) if self.h_scale is not None else _HS
        return self._exp_weighted_mean(arr3d[indices], z_m, h)

    @staticmethod
    def _exp_weighted_mean(
        slices: np.ndarray, z_m: np.ndarray, h: float,
    ) -> np.ndarray:
        """Weighted mean over axis 0 with weights ``exp(-z_m / h)``.

        Only (value, weight) pairs where BOTH are finite contribute, to the
        numerator and the denominator alike.  A NaN value whose weight
        stayed in the denominator would bias the mean toward zero.  Points
        with no usable pair are NaN.
        """
        vals = np.asarray(slices, dtype=np.float64)
        wt = np.exp(-np.asarray(z_m, dtype=np.float64) / h)
        ok = np.isfinite(vals) & np.isfinite(wt)
        num = np.where(ok, vals * wt, 0.0).sum(axis=0)
        den = np.where(ok, wt, 0.0).sum(axis=0)
        out = np.full(num.shape, np.nan, dtype=np.float64)
        m = den > 0
        out[m] = num[m] / den[m]
        return out

    def available_dh(
        self, stage: str, *, variant: str | None = "original",
    ) -> list[int]:
        """Sorted dh offsets that have at least one accumulated file."""
        if self._is_original(variant):
            return sorted(self.counts.get(stage, {}).keys())
        return sorted(
            self.counts_v.get(variant, {}).get(stage, {}).keys()
        )

    # ── I/O ──
    def save(self, path: Path | str) -> Path:
        """Pickle this result to *path*, creating parent directories."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump(self, f, protocol=pickle.HIGHEST_PROTOCOL)
        print(f"[saved] composite → {path}", flush=True)
        return path

    @classmethod
    def load(cls, path: Path | str) -> "CompositeResult":
        """Unpickle a result written by :meth:`save`.

        Security: ``pickle.load`` can execute arbitrary code; only load
        files from a trusted source.
        """
        with open(Path(path), "rb") as f:
            return pickle.load(f)
# ── Builder ───────────────────────────────────────────────────────────
[docs]
def build_composites(
    cfg: CompositeConfig,
    rwb: ClassifyResult | None = None,
) -> CompositeResult:
    """Accumulate NPZ fields into variant-aware composites.

    Scans ``cfg.npz_dir / {stage} / {dh dir} / *.npz`` for every stage in
    ``cfg.stages``.  Sub-directories recognised by ``_parse_dh`` are
    processed in increasing dh order.  Files whose track id is in the
    exclude list are skipped.

    For each remaining file, every 3-D numeric field that is neither
    metadata (``_META``) nor LS-derived (``prp__`` / ``int__`` / ``ax__`` /
    ``ay__`` / ``beta__`` prefixes) is re-ordered onto the reference
    levels.  It is then added to the ``original`` accumulators and to those
    of every RWB variant whose track set contains the file's track id.

    Each file is handled all-or-nothing.  All of its fields are read and
    shape-checked before anything is accumulated, so an unreadable or
    inconsistent file cannot leave some fields accumulated and others
    not.  Such files are skipped and their number is reported.

    Args:
        cfg: Composite configuration (directories, stages).
        rwb: Optional RWB classification result. If *None*, only the
            ``original`` variant (all events) is produced.

    Returns:
        :class:`CompositeResult` with accumulated sums/counts.
    """
    excluded = _load_excluded(cfg.exclude_file)
    variant_trackset = rwb.variant_trackset if rwb is not None else {}
    variants = list(variant_trackset.keys())

    # ── accumulators ──
    sums: dict[str, dict[int, dict]] = defaultdict(_dd_dict)
    valids: dict[str, dict[int, dict]] = defaultdict(_dd_dict)
    counts: dict[str, dict[int, int]] = defaultdict(_dd_int)
    sums_v: dict[str, dict[str, dict[int, dict]]] = {
        v: defaultdict(_dd_dict) for v in variants
    }
    valids_v: dict[str, dict[str, dict[int, dict]]] = {
        v: defaultdict(_dd_dict) for v in variants
    }
    counts_v: dict[str, dict[str, dict[int, int]]] = {
        v: defaultdict(_dd_int) for v in variants
    }

    LEVELS: np.ndarray | None = None
    X_REL = Y_REL = None
    H_SCALE: float | None = None
    fields_3d: set[str] = set()

    print("\n[pass2] Accumulating composites ...", flush=True)
    for evt in cfg.stages:
        evt_dir = cfg.npz_dir / evt
        if not evt_dir.exists():
            continue
        for dh, dh_dir in _sorted_dh_dirs(evt_dir):
            npz_files = sorted(dh_dir.glob("*.npz"))
            if not npz_files:
                continue
            n_total = n_loaded = n_failed = 0
            for fp in npz_files:
                n_total += 1
                tid = _parse_track_id(fp)
                if tid is not None and tid in excluded:
                    continue
                # Deliberately best-effort across many files: a corrupt or
                # malformed file is skipped, but it is counted and reported
                # instead of being silently hidden.
                try:
                    rec = _read_npz_fields(fp, LEVELS)
                except Exception:  # noqa: BLE001
                    n_failed += 1
                    continue
                if rec is None:
                    continue
                levels_ref, x_rel, y_rel, h_scale, fields = rec

                # Reject a file whose grid differs from what has already
                # been accumulated here.  `.get` avoids creating entries.
                acc = sums.get(evt, {}).get(dh, {})
                if any(k in acc and acc[k].shape != a3.shape
                       for k, a3 in fields.items()):
                    n_failed += 1
                    continue

                # Grid metadata: each item comes from the first accepted
                # file that provides it.
                if LEVELS is None:
                    LEVELS = levels_ref
                if X_REL is None:
                    X_REL = x_rel
                if Y_REL is None:
                    Y_REL = y_rel
                if H_SCALE is None:
                    H_SCALE = h_scale

                # Variant membership is a per-file property; test it once.
                member_vars = [
                    var for var in variants if tid in variant_trackset[var]
                ]
                for k, a3 in fields.items():
                    fields_3d.add(k)
                    _accumulate(sums[evt][dh], valids[evt][dh], k, a3)
                    for var in member_vars:
                        _accumulate(
                            sums_v[var][evt][dh],
                            valids_v[var][evt][dh],
                            k, a3,
                        )
                counts[evt][dh] += 1
                for var in member_vars:
                    counts_v[var][evt][dh] += 1
                n_loaded += 1
            print(
                f"[{evt}] dh={dh:+d}: total={n_total} loaded={n_loaded}",
                flush=True,
            )
            if n_failed:
                print(
                    f"[{evt}] dh={dh:+d}: skipped {n_failed} "
                    f"unreadable/inconsistent file(s)",
                    flush=True,
                )
    print(f"[pass2] 3D fields discovered: {sorted(fields_3d)}", flush=True)
    all_variants = ["original"] + variants
    return CompositeResult(
        levels=LEVELS if LEVELS is not None else np.array([], dtype=int),
        x_rel=X_REL if X_REL is not None else np.array([]),
        y_rel=Y_REL if Y_REL is not None else np.array([]),
        h_scale=H_SCALE,
        stages=list(cfg.stages),
        fields_3d=sorted(fields_3d),
        sums=dict(sums),
        valids=dict(valids),
        counts=dict(counts),
        sums_v=dict(sums_v),
        valids_v=dict(valids_v),
        counts_v=dict(counts_v),
        variant_names=all_variants,
    )


# Prefixes of LS-derived fields that are never composited.
_LS_PREFIXES = ("prp__", "int__", "ax__", "ay__", "beta__")


def _sorted_dh_dirs(evt_dir: Path) -> list[tuple[int, Path]]:
    """List ``(dh, path)`` for sub-directories ``_parse_dh`` recognises."""
    found: list[tuple[int, Path]] = []
    for d in sorted(evt_dir.iterdir()):
        if not d.is_dir():
            continue
        dh_val = _parse_dh(d.name)
        if dh_val is not None:
            found.append((dh_val, d))
    # Stable sort: directories with equal dh keep their name order.
    found.sort(key=lambda x: x[0])
    return found


def _read_npz_fields(
    fp: Path, levels_ref: np.ndarray | None,
) -> tuple[np.ndarray, Any, Any, float | None, dict[str, np.ndarray]] | None:
    """Read one NPZ file fully into memory; None if it has no usable data.

    A file is unusable (returns ``None``) when it has neither ``pv_3d``
    nor ``z_3d``, when its levels cannot be mapped onto the reference, or
    when that probe field is not 3-D.

    Returns:
        ``(levels, x_rel, y_rel, h_scale, fields)``:

        - *levels* is *levels_ref*, or this file's levels cast to ``int``
          when no reference exists yet.
        - *x_rel* / *y_rel* / *h_scale* are the file's ``X_rel`` /
          ``Y_rel`` / ``H_SCALE`` entries, or ``None`` when absent.
        - *fields* maps every composite-able 3-D numeric field to its
          array re-ordered onto *levels*.

    Raises:
        Whatever ``np.load`` / array access raises for a corrupt file.
        The caller decides how to treat it.
    """
    with np.load(fp, allow_pickle=False) as Z:
        names = Z.files
        levels_file = Z["levels"]
        if "pv_3d" in names:
            probe_key = "pv_3d"
        elif "z_3d" in names:
            probe_key = "z_3d"
        else:
            return None
        levels = (levels_file.astype(int).copy()
                  if levels_ref is None else levels_ref)
        idx = _levels_indexer(levels_file, levels)
        if idx is None or Z[probe_key][idx].ndim != 3:
            return None
        x_rel = Z["X_rel"] if "X_rel" in names else None
        y_rel = Z["Y_rel"] if "Y_rel" in names else None
        h_scale = float(Z["H_SCALE"]) if "H_SCALE" in names else None
        fields: dict[str, np.ndarray] = {}
        for k in names:
            # Cheap name checks first, so skipped arrays are never loaded.
            if k in _META or k.startswith(_LS_PREFIXES):
                continue
            a = Z[k]
            # Only real numeric (bool/int/uint/float) 3-D arrays can be
            # NaN-summed.
            if a.ndim != 3 or a.dtype.kind not in "biuf":
                continue
            fields[k] = a[idx]
    return levels, x_rel, y_rel, h_scale, fields