"""NPZ composite accumulation and PKL export/load.
Accumulates per-event NPZ patch fields into running sums and valid-point
counts, grouped by event stage (onset/peak/decay) and RWB variant
(original/AWB_onset/CWB_peak/etc.). The accumulated state can be exported
to a pickle file for rapid loading in analysis notebooks.
"""
from __future__ import annotations
import math
import pickle
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Mapping, MutableMapping
import numpy as np
[docs]
@dataclass(frozen=True)
class CompositeState:
    """Read-only wrapper around exported composite globals.

    The wrapped mapping holds running sums and valid-point counts keyed by
    stage, hour offset and field name (plus per-variant copies under the
    ``*_V`` keys).  `composite_mean_3d()` turns a sum/count pair into a
    composite mean and `composite_reduce()` optionally collapses the
    leading (level) axis of that mean.

    Attributes:
        globals_map: Mapping of global names (``"FIELDS3D"``, ``"LEVELS"``,
            ``"SUMS3D"``, ...) to their exported values.
    """

    globals_map: Mapping[str, object]

    def __post_init__(self) -> None:
        """Check that every global the accessors rely on is present.

        Raises:
            KeyError: If one or more required names are missing.
        """
        required = {
            "FIELDS3D", "LEVELS", "SUMS3D", "VALID3D",
            "FILECOUNT", "SUMS3D_V", "VALID3D_V",
            "FILECOUNT_V", "COMPOSITE_VARIANTS",
        }
        missing = sorted(name for name in required if name not in self.globals_map)
        if missing:
            raise KeyError(f"Composite globals missing: {missing}")

    def list_fields3d(self) -> tuple[str, ...]:
        """Return the tuple of 3D field names."""
        return tuple(self.globals_map.get("FIELDS3D", ()))

    def composite_mean_3d(
        self,
        field: str,
        stage: str,
        dh: int,
        *,
        variant: str | None = "original",
    ) -> np.ndarray | None:
        """Return composite mean of a 3D field.

        Parameters:
            field: Field name (e.g., 'pv_3d', 'z_3d').
            stage: Event stage name (e.g., 'onset', 'peak', 'decay').
            dh: Hour offset from stage reference time.
            variant: Composite variant key (default 'original').

        Returns:
            float64 array of sum / valid-count, NaN wherever the valid
            count is not positive, or None if no sums/counts are stored
            for the requested variant, stage, offset or field.

        Raises:
            KeyError: If the stage or variant name is unknown.
        """
        variant_key = self._norm_variant(variant)
        stage_key = self._norm_stage(stage)
        sums, valids = self._pick_store3d(variant_key, stage_key, int(dh))
        if sums is None or valids is None:
            return None
        arr_sum = _safe_lookup(sums, field)
        vcount = _safe_lookup(valids, field)
        if arr_sum is None or vcount is None:
            return None
        arr_sum = np.asarray(arr_sum, dtype=np.float64)
        vcount = np.asarray(vcount, dtype=np.float64)
        # Start from NaN so points that never received a valid sample stay
        # NaN; the division only writes where the count is positive.
        out = np.full_like(arr_sum, np.nan, dtype=np.float64)
        mask = vcount > 0
        with np.errstate(invalid="ignore", divide="ignore"):
            np.divide(arr_sum, vcount, out=out, where=mask)
        return out

    def composite_reduce(
        self,
        field: str,
        stage: str,
        dh: int,
        *,
        variant: str | None = "original",
        level_mode=None,
    ) -> np.ndarray | None:
        """Get composite-mean field, optionally reduced along the level axis.

        Parameters:
            field: Field name (e.g., 'pv_3d', 'z_3d').
            stage: Event stage name.
            dh: Hour offset from stage reference time.
            variant: Composite variant key (default 'original').
            level_mode: One of None/''/'all'/'3d' (return a copy of the full
                array), 'wavg'/'weighted' (height-weighted mean over axis 0),
                or a numeric level; the entry of ``LEVELS`` nearest to it
                selects the index along axis 0.  String modes are matched
                case-insensitively, ignoring surrounding whitespace.

        Returns:
            Numpy array (full or reduced depending on level_mode), or None
            if the composite mean itself is unavailable.

        Raises:
            KeyError: If the stage/variant is unknown, or no z field is
                listed when a weighted average is requested.
            ValueError: If the z composite is unavailable for a weighted
                average, if ``level_mode`` is not a recognised mode or
                number, or if ``LEVELS`` is empty.
        """
        arr3d = self.composite_mean_3d(field, stage, dh, variant=variant)
        if arr3d is None:
            return None

        # Normalise string modes once so every keyword is case-insensitive.
        if isinstance(level_mode, str):
            mode = level_mode.strip().lower()
        else:
            mode = level_mode

        if mode is None or (isinstance(mode, str) and mode in {"", "all", "3d"}):
            return np.array(arr3d, copy=True)

        if isinstance(mode, str) and mode in {"wavg", "weighted"}:
            z3d = self.composite_mean_3d(
                self._resolve_z_name(), stage, dh, variant=variant)
            if z3d is None:
                raise ValueError("No z_3d for weighted average")
            # H_SCALE is the e-folding height of the weights; presumably in
            # the same units as the z field (default 7000.0) -- TODO confirm.
            h_scale = float(self.globals_map.get("H_SCALE", 7000.0))
            return self._weighted_vertical_mean(arr3d, z3d, h_scale)

        level_val = float(mode)
        levels = np.asarray(self.globals_map.get("LEVELS", ()), dtype=float)
        if levels.size == 0:
            raise ValueError("No LEVELS available for level selection")
        idx = int(np.nanargmin(np.abs(levels - level_val)))
        return arr3d[idx]

    @staticmethod
    def _weighted_vertical_mean(
        arr3d: np.ndarray,
        z3d: np.ndarray,
        h_scale: float,
    ) -> np.ndarray:
        """Average ``arr3d`` over axis 0 with weights ``exp(-z3d / h_scale)``.

        A level contributes to BOTH numerator and denominator only where the
        field value and its weight are non-NaN.  Counting the weight of a
        level whose field value is missing would bias the mean toward zero.

        Returns:
            Array with axis 0 removed; NaN where no level contributed.
        """
        weights = np.exp(-z3d / h_scale)
        usable = ~(np.isnan(arr3d) | np.isnan(weights))
        weights = np.where(usable, weights, 0.0)
        values = np.where(usable, arr3d, 0.0)
        num = np.sum(weights * values, axis=0)
        den = np.sum(weights, axis=0)
        out = np.full_like(num, np.nan)
        positive = den > 0
        out[positive] = num[positive] / den[positive]
        return out

    def _norm_stage(self, stage: str) -> str:
        """Normalize stage name to match stored keys (case-insensitive).

        Raises:
            KeyError: If no stored stage name matches.
        """
        wanted = str(stage).strip().lower()
        for name in self._stage_names:
            if name.lower() == wanted:
                return name
        raise KeyError(f"Unknown stage {stage!r}")

    def _norm_variant(self, variant: str | None) -> str:
        """Normalize variant name to match stored keys (case-insensitive).

        A falsy variant means 'original'.  Names listed in
        ``COMPOSITE_VARIANTS`` are returned with their stored spelling.
        'original' is also accepted when it is not listed there, because
        `_pick_store3d` serves it from the un-suffixed ``SUMS3D`` /
        ``VALID3D`` stores rather than from a per-variant entry.

        Raises:
            KeyError: If the name matches neither a listed variant nor
                'original'.
        """
        if not variant:
            return "original"
        wanted = str(variant).strip().lower()
        for candidate in self.globals_map.get("COMPOSITE_VARIANTS", ()):
            if str(candidate).lower() == wanted:
                return str(candidate)
        if wanted == "original":
            return "original"
        raise KeyError(f"Unknown variant {variant!r}")

    def _pick_store3d(self, variant, stage, dh):
        """Select the sums/valids mappings for a variant + stage + dh.

        Returns:
            ``(sums, valids)`` for the requested offset; either element is
            None when the corresponding entry is not stored.
        """
        if variant == "original":
            sums_map = self.globals_map.get("SUMS3D")
            valids_map = self.globals_map.get("VALID3D")
        else:
            # _safe_lookup tolerates a missing or None per-variant store.
            sums_map = _safe_lookup(self.globals_map.get("SUMS3D_V"), variant)
            valids_map = _safe_lookup(self.globals_map.get("VALID3D_V"), variant)
        s_stage = _safe_lookup(sums_map, stage)
        v_stage = _safe_lookup(valids_map, stage)
        if s_stage is None or v_stage is None:
            return None, None
        return _safe_lookup(s_stage, dh), _safe_lookup(v_stage, dh)

    def _resolve_z_name(self) -> str:
        """Resolve the geopotential height field name ('z_3d', else 'z').

        Raises:
            KeyError: If neither name is listed in ``FIELDS3D``.
        """
        fields = self.list_fields3d()
        for candidate in ("z_3d", "z"):
            if candidate in fields:
                return candidate
        raise KeyError("No z_3d/z field found")

    @property
    def _stage_names(self) -> tuple[str, ...]:
        """Return stage names from the FILECOUNT mapping keys.

        Any Mapping is accepted (not only ``dict``); anything else yields
        an empty tuple.
        """
        fc = self.globals_map.get("FILECOUNT", {})
        return tuple(fc.keys()) if isinstance(fc, Mapping) else ()
def _safe_lookup(mapping, key):
"""Safely look up a key in a mapping-like object.
Parameters:
mapping: A dict, Mapping, or object with a .get() method.
key: The key to look up.
Returns:
The value, or None if not found or mapping is None.
"""
if mapping is None:
return None
if isinstance(mapping, Mapping):
return mapping.get(key)
if hasattr(mapping, "get"):
return mapping.get(key)
return None
[docs]
def load_composite_state(path: str | Path) -> CompositeState:
    """Read a pickled composite export and wrap it in a CompositeState.

    The pickle may hold the globals mapping directly, or a one-entry dict
    ``{"globals": mapping}``; in the latter case the inner mapping is used.

    NOTE: unpickling can execute arbitrary code, so only load files that
    come from a trusted source.

    Parameters:
        path: Path to the exported pickle.

    Returns:
        CompositeState instance.
    """
    with open(path, "rb") as stream:
        payload = pickle.load(stream)
    # Unwrap only when "globals" is the sole key of a plain dict.
    is_wrapped = (
        isinstance(payload, dict)
        and len(payload) == 1
        and "globals" in payload
    )
    if is_wrapped:
        payload = payload["globals"]
    return CompositeState(globals_map=payload)
def save_composite_state(
    state_dict: Mapping[str, object],
    path: str | Path,
) -> Path:
    """Pickle the composite globals to ``path`` using the highest protocol.

    Missing parent directories are created.  The file content is a
    one-entry dict ``{"globals": <shallow dict copy of state_dict>}``.

    Parameters:
        state_dict: Dictionary of composite globals.
        path: Output path.

    Returns:
        Path to written file.
    """
    target = Path(path)
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("wb") as sink:
        # Wrap under "globals" so the file is self-describing.
        payload = {"globals": dict(state_dict)}
        pickle.dump(payload, sink, protocol=pickle.HIGHEST_PROTOCOL)
    return target