# Source code for pvtend.classify

"""RWB (Rossby Wave Breaking) classification of tracked events.

Reads the dh=0 NPZ snapshots produced by :mod:`pvtend.tendency`,
classifies each event as AWB / CWB / NEUTRAL at multiple pressure
levels, and emits a "variant tracksets" PKL that the composite
builder can read.

This corresponds to **Pass 1** of the core script
``ss01_rwb_stage_multilevel_composites.py``.

Usage (via CLI)::

    pvtend-pipeline classify \\
        --npz-dir /path/to/composite_blocking_tempest \\
        --output  /path/to/outputs/rwb_variant_tracksets.pkl \\
        --stages  onset peak decay \\
        --levels  500 400 300 200 \\
        --threshold 3

Or programmatically::

    from pvtend.classify import run_pass1, ClassifyConfig
    cfg = ClassifyConfig(npz_dir=Path("..."))
    result = run_pass1(cfg)
    result.save("rwb_variant_tracksets.pkl")
"""

from __future__ import annotations

import csv
import pickle
import re
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Sequence

import numpy as np

from .rwb import (
    RWBConfig,
    sampled_longest_contours,
    overturn_x_intervals,
    envelope_polygon,
    poly_area_centroid,
    classify_bay,
    centerline_tilt,
)

# ── regex helpers ────────────────────────────────────────────────────
# Finds a "track_<digits>_" token; group 1 is the digit run. Applied with
# ``search`` to file names by ``_parse_track_id``.
_TRACK_RE = re.compile(r"track_(\d+)_")
# Matches a whole "dh=<optionally signed integer>" string (e.g. "dh=+0");
# group 1 is the signed integer text. Applied to directory names by
# ``_parse_dh``.
_DH_RE = re.compile(r"^dh=([+\-]?\d+)$")


def _parse_track_id(fp: Path) -> int | None:
    """Return the track ID embedded in the file name of *fp*.

    The ID is the digit run captured by ``_TRACK_RE`` (a
    ``track_<digits>_`` token anywhere in ``fp.name``).

    Args:
        fp: Path whose final component is inspected.

    Returns:
        The track ID as an ``int``, or ``None`` when the name carries no
        ``track_<digits>_`` token.
    """
    found = _TRACK_RE.search(fp.name)
    if found is None:
        return None
    return int(found.group(1))


def _parse_dh(dirname: str) -> int | None:
    """Return the integer offset encoded in a ``dh=<N>`` directory name.

    Args:
        dirname: Directory name such as ``"dh=+0"`` or ``"dh=-6"``.

    Returns:
        The signed offset as an ``int``, or ``None`` when *dirname* does not
        match ``_DH_RE``.
    """
    found = _DH_RE.match(dirname)
    if found is None:
        return None
    return int(found.group(1))


# ── Config ────────────────────────────────────────────────────────────

@dataclass
class ClassifyConfig:
    """Configuration for Pass-1 RWB classification.

    Path-like fields may be given as ``str`` or :class:`~pathlib.Path`;
    they are normalised to ``Path`` after construction so that downstream
    code can rely on the ``/`` operator and ``Path`` methods.

    Attributes:
        npz_dir: Root directory containing stage sub-directories
            (``onset/``, ``peak/``, ``decay/``), each with ``dh=±N``
            subdirectories that hold per-event NPZ files.
        output_path: Where to save the resulting variant-tracksets PKL.
        stages: List of event stages to process.
        classify_levels: Levels checked for RWB. Entries are pressure
            levels [hPa] or the string ``"wavg"`` (weighted-average field).
        classify_threshold: Number of levels that must agree.
        rwb_cfg: Fine-grained RWB bay-detection settings.
        exclude_file: Optional CSV listing track IDs to skip.
    """

    npz_dir: Path = Path(".")
    output_path: Path = Path("rwb_variant_tracksets.pkl")
    stages: list[str] = field(
        default_factory=lambda: ["onset", "peak", "decay"]
    )
    classify_levels: list[int | str] = field(
        default_factory=lambda: [500, 400, 300, 200]
    )
    classify_threshold: int = 3
    rwb_cfg: RWBConfig = field(
        default_factory=lambda: RWBConfig(area_min_deg2=20.0, try_levels=400)
    )
    exclude_file: Path | None = None

    def __post_init__(self) -> None:
        """Coerce path-like fields to ``Path`` (``exclude_file`` may stay None)."""
        # CLI / config layers commonly hand over plain strings; converting
        # here keeps expressions like ``cfg.npz_dir / stage`` working.
        self.npz_dir = Path(self.npz_dir)
        self.output_path = Path(self.output_path)
        if self.exclude_file is not None:
            self.exclude_file = Path(self.exclude_file)
# ── Excluded track loader ─────────────────────────────────────────────


def _load_excluded(p: Path | None) -> set[int]:
    """Load the track IDs to skip from an optional exclusion file.

    The layout is chosen by sniffing the first 1024 characters:

    * If they contain a comma, the file is read as CSV with a header row
      and IDs are taken from its ``track_id`` column.
    * Otherwise the file is read as plain text and the first run of digits
      on each line is taken as an ID.

    Loading is best-effort: rows whose ID cannot be parsed are skipped, and
    an unreadable file yields whatever was collected so far (a warning is
    printed rather than raising).

    Args:
        p: Path to the exclusion file, or ``None``.

    Returns:
        Set of track IDs; empty when *p* is ``None`` or does not exist.
    """
    ids: set[int] = set()
    if p is None or not p.exists():
        return ids
    try:
        with open(p, "r", newline="") as f:
            sniff = f.read(1024)
            f.seek(0)
            if "," in sniff:
                reader = csv.DictReader(f)
                if "track_id" not in (reader.fieldnames or ()):
                    print(
                        f"[warn] exclude file {p} has no 'track_id' column; "
                        "no tracks excluded",
                        flush=True,
                    )
                    return ids
                for row in reader:
                    try:
                        ids.add(int(row["track_id"]))
                    except (TypeError, ValueError, KeyError):
                        # TypeError: a short row leaves the field as None.
                        # Skip just this row instead of aborting the load.
                        continue
            else:
                for line in f:
                    m = re.search(r"\d+", line)
                    if m:
                        ids.add(int(m.group(0)))
    except (OSError, ValueError, csv.Error) as exc:
        # ValueError also covers UnicodeDecodeError raised while reading.
        print(f"[warn] could not fully read exclude file {p}: {exc}", flush=True)
    return ids


# ── Single-level bay classifier ───────────────────────────────────────


def _classify_bays_z2d(
    z2d: np.ndarray,
    x_rel: np.ndarray,
    y_rel: np.ndarray,
    cfg: RWBConfig,
) -> tuple[bool, bool]:
    """Detect AWB / CWB bays on one 2-D Z field (relative coords).

    Args:
        z2d: 2-D field to contour.
        x_rel: Relative x coordinates; when 2-D, its first row is used.
        y_rel: Relative y coordinates; when 2-D, its first column is used.
        cfg: Bay-detection settings forwarded to the ``rwb`` helpers.

    Returns:
        ``(is_awb, is_cwb)``: whether at least one qualifying bay of each
        type was found. Both are ``False`` when *z2d* has no finite value or
        no contour is returned.
    """
    if not np.isfinite(z2d).any():
        return False, False

    x = x_rel[0, :] if x_rel.ndim == 2 else x_rel
    y = y_rel[:, 0] if y_rel.ndim == 2 else y_rel

    contours = sampled_longest_contours(
        z2d,
        x,
        y,
        try_levels=cfg.try_levels,
        max_keep=12,
        min_vertices=cfg.min_vertices,
    )
    if not contours:
        return False, False

    is_awb = is_cwb = False
    for c in contours:
        xline, yline = c["x"], c["y"]
        intervals = overturn_x_intervals(
            xline,
            yline,
            n_meridians=cfg.n_meridians,
            min_cross=cfg.min_cross,
        )
        for xa, xb in intervals:
            poly = envelope_polygon(
                xline,
                yline,
                xa,
                xb,
                n_samp=cfg.n_samp,
                min_points=cfg.min_points,
            )
            if poly is None:
                continue
            xp, yp, xm, y_min, y_max = poly
            area, _ = poly_area_centroid(xp, yp)
            # Bays at or below the minimum area are ignored.
            if abs(area) <= cfg.area_min_deg2:
                continue
            wb_type, _ = classify_bay(
                xline,
                yline,
                xa,
                xb,
                n_samp=max(80, cfg.n_samp // 2),
            )
            if wb_type == "UNK":
                # Fall back to the sign of the centre-line tilt.
                slope = centerline_tilt(xm, y_min, y_max)
                if not np.isfinite(slope):
                    continue
                wb_type = "AWB" if slope < 0 else "CWB"
            if wb_type == "AWB":
                is_awb = True
            if wb_type == "CWB":
                is_cwb = True
            # Nothing more can be learned once both types were seen.
            if is_awb and is_cwb:
                return True, True
    return is_awb, is_cwb


def _classify_multilevel(
    z3d: np.ndarray | None,
    levels_file: np.ndarray | None,
    x_rel: np.ndarray,
    y_rel: np.ndarray,
    *,
    classify_levels: Sequence[int | str],
    threshold: int,
    cfg: RWBConfig,
    z2d_wavg: np.ndarray | None = None,
) -> tuple[bool, bool]:
    """Multi-level classification; require *threshold* levels to agree.

    *classify_levels* may contain integer hPa values or the string
    ``"wavg"``; the latter uses the pre-computed weighted-average 2-D Z
    field (*z2d_wavg*).

    Args:
        z3d: Z field indexed by level along its first axis, or ``None``.
        levels_file: Level values matching the first axis of *z3d*, or
            ``None``.
        x_rel: Relative x coordinates passed to the single-level classifier.
        y_rel: Relative y coordinates passed to the single-level classifier.
        classify_levels: Levels to test (see above).
        threshold: Minimum number of levels that must flag a type.
        cfg: Bay-detection settings.
        z2d_wavg: Field used for ``"wavg"`` entries, or ``None``.

    Returns:
        ``(awb, cwb)``: each ``True`` when at least *threshold* of the
        tested levels flagged that type. Levels whose data is unavailable
        are skipped and contribute to neither count.
    """
    awb_count = cwb_count = 0
    for lev in classify_levels:
        if isinstance(lev, str) and lev.lower() == "wavg":
            if z2d_wavg is None:
                continue
            awb, cwb = _classify_bays_z2d(z2d_wavg, x_rel, y_rel, cfg)
        else:
            if z3d is None or levels_file is None:
                continue
            # NOTE(review): picks the nearest stored level, however far it
            # is from the requested one; confirm that is intended.
            k = int(np.nanargmin(np.abs(levels_file - int(lev))))
            if k >= z3d.shape[0]:
                continue
            awb, cwb = _classify_bays_z2d(z3d[k], x_rel, y_rel, cfg)
        awb_count += int(awb)
        cwb_count += int(cwb)
    return awb_count >= threshold, cwb_count >= threshold


# ── Result container ──────────────────────────────────────────────────
@dataclass
class ClassifyResult:
    """Holds RWB variant classification results.

    Attributes:
        stage_all: ``{stage: set_of_track_ids}``
        stage_awb: ``{stage: set_of_AWB_track_ids}``
        stage_cwb: ``{stage: set_of_CWB_track_ids}``
        stage_neu: ``{stage: set_of_NEUTRAL_track_ids}``
        h_scale: Captured from the first NPZ file.
        stages: Ordered stage names.
        classify_levels: Levels used (hPa integers and/or ``"wavg"``).
        classify_threshold: Threshold used.
    """

    stage_all: dict[str, set[int]]
    stage_awb: dict[str, set[int]]
    stage_cwb: dict[str, set[int]]
    stage_neu: dict[str, set[int]]
    h_scale: float | None
    stages: list[str]
    classify_levels: list[int | str]
    classify_threshold: int

    # ── derived look-ups ──

    @property
    def variant_trackset(self) -> dict[str, frozenset[int]]:
        """Variant → frozenset mapping, e.g. ``AWB_onset``."""
        out: dict[str, frozenset[int]] = {}
        for evt in self.stages:
            out[f"AWB_{evt}"] = frozenset(self.stage_awb.get(evt, set()))
            out[f"CWB_{evt}"] = frozenset(self.stage_cwb.get(evt, set()))
            out[f"NEUTRAL_{evt}"] = frozenset(self.stage_neu.get(evt, set()))
        return out

    @property
    def stage_labels(self) -> dict[str, dict[int, str]]:
        """``{stage: {track_id: label}}`` where label ∈ AWB/CWB/NEUTRAL/Omega.

        A track present in both the AWB and CWB sets of a stage is labelled
        ``"Omega"``.
        """
        out: dict[str, dict[int, str]] = {}
        for evt in self.stages:
            awb = self.stage_awb.get(evt, set())
            cwb = self.stage_cwb.get(evt, set())
            both = awb & cwb
            labels: dict[int, str] = {}
            for tid in sorted(self.stage_all.get(evt, set())):
                if tid in both:
                    labels[tid] = "Omega"
                elif tid in awb:
                    labels[tid] = "AWB"
                elif tid in cwb:
                    labels[tid] = "CWB"
                else:
                    labels[tid] = "NEUTRAL"
            out[evt] = labels
        return out

    @property
    def stage_tracksets(self) -> dict[str, dict[str, frozenset[int]]]:
        """``{stage: {"ALL"|"AWB"|"CWB"|"NEUTRAL"|"Omega": frozenset}}``.

        ``"Omega"`` is the intersection of the stage's AWB and CWB sets.
        """
        out: dict[str, dict[str, frozenset[int]]] = {}
        for evt in self.stages:
            awb = self.stage_awb.get(evt, set())
            cwb = self.stage_cwb.get(evt, set())
            out[evt] = {
                "ALL": frozenset(self.stage_all.get(evt, set())),
                "AWB": frozenset(awb),
                "CWB": frozenset(cwb),
                "NEUTRAL": frozenset(self.stage_neu.get(evt, set())),
                "Omega": frozenset(awb & cwb),
            }
        return out

    # ── I/O ──

    def save(self, path: Path | str) -> Path:
        """Persist to pickle (core-script keys plus ``STAGES``).

        Parent directories are created as needed. ``STAGES`` records the
        stage order so that :meth:`load` can restore it exactly.

        Args:
            path: Destination file.

        Returns:
            The destination as a ``Path``.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        data = {
            "stage_ALL": {k: set(v) for k, v in self.stage_all.items()},
            "stage_AWB": {k: set(v) for k, v in self.stage_awb.items()},
            "stage_CWB": {k: set(v) for k, v in self.stage_cwb.items()},
            "stage_NEU": {k: set(v) for k, v in self.stage_neu.items()},
            "H_SCALE": self.h_scale,
            "variant_trackset": self.variant_trackset,
            "RWB_STAGE_LABELS": self.stage_labels,
            "RWB_STAGE_TRACKSETS": self.stage_tracksets,
            "CLASSIFY_LEVELS": self.classify_levels,
            "CLASSIFY_THRESHOLD": self.classify_threshold,
            "STAGES": list(self.stages),
        }
        with open(path, "wb") as f:
            pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
        print(f"[saved] RWB variant tracksets → {path}", flush=True)
        return path

    @classmethod
    def load(cls, path: Path | str) -> "ClassifyResult":
        """Load from a previously-saved PKL.

        Files written without a ``STAGES`` entry fall back to the
        alphabetically sorted keys of ``stage_ALL``.

        Args:
            path: Pickle file produced by :meth:`save`.

        Returns:
            The reconstructed result.
        """
        path = Path(path)
        # SECURITY: unpickling executes arbitrary code; only load files
        # from a trusted source.
        with open(path, "rb") as f:
            d = pickle.load(f)
        stages = d.get("STAGES")
        if stages is None:
            stages = sorted(d["stage_ALL"].keys())
        return cls(
            stage_all={k: set(v) for k, v in d["stage_ALL"].items()},
            stage_awb={k: set(v) for k, v in d["stage_AWB"].items()},
            stage_cwb={k: set(v) for k, v in d["stage_CWB"].items()},
            stage_neu={k: set(v) for k, v in d["stage_NEU"].items()},
            h_scale=d.get("H_SCALE"),
            stages=list(stages),
            classify_levels=d.get("CLASSIFY_LEVELS", [500, 400, 300, 200]),
            classify_threshold=d.get("CLASSIFY_THRESHOLD", 3),
        )
# ── Main entry point ──────────────────────────────────────────────────
def _is_wavg_level(level: int | str) -> bool:
    """Return True when *level* is the ``"wavg"`` marker (case-insensitive)."""
    return isinstance(level, str) and level.lower() == "wavg"


def _find_dh0_dir(stage_dir: Path) -> Path | None:
    """Return the first existing ``dh=0`` sub-directory of *stage_dir*, if any."""
    for name in ("dh=+0", "dh=0", "dh=-0"):
        candidate = stage_dir / name
        if candidate.exists():
            return candidate
    return None


def run_pass1(cfg: ClassifyConfig) -> ClassifyResult:
    """Run Pass-1 RWB classification from NPZ files.

    Reads ``dh=0`` snapshots under ``cfg.npz_dir/{stage}/dh=+0/`` (also
    accepting ``dh=0`` / ``dh=-0``) and classifies each track as
    AWB / CWB / NEUTRAL.

    Every non-excluded track whose ID can be parsed from the file name is
    added to the stage's ALL set before its file is read, so tracks that are
    skipped or fail to load end up NEUTRAL. Progress, warnings and per-stage
    summaries are printed to stdout.

    Args:
        cfg: Classification settings.

    Returns:
        :class:`ClassifyResult` holding variant sets.
    """
    excluded = _load_excluded(cfg.exclude_file)
    if excluded:
        print(f"[exclude] {len(excluded)} track IDs", flush=True)

    h_scale: float | None = None
    stage_all: dict[str, set[int]] = {e: set() for e in cfg.stages}
    stage_awb: dict[str, set[int]] = {e: set() for e in cfg.stages}
    stage_cwb: dict[str, set[int]] = {e: set() for e in cfg.stages}
    stage_neu: dict[str, set[int]] = {}

    # Only read the arrays the requested levels actually need.
    need_wavg = any(_is_wavg_level(lev) for lev in cfg.classify_levels)
    need_3d = any(not _is_wavg_level(lev) for lev in cfg.classify_levels)

    print(f"\n[pass1] classifying at levels {cfg.classify_levels} "
          f"(threshold={cfg.classify_threshold})", flush=True)

    for evt in cfg.stages:
        evt_dir = cfg.npz_dir / evt
        if not evt_dir.exists():
            print(f"[warn] stage directory not found: {evt_dir}", flush=True)
            continue

        dh0_dir = _find_dh0_dir(evt_dir)
        if dh0_dir is None:
            print(f"[warn] no dh=0 directory for {evt}", flush=True)
            continue

        n_ok = n_skip = n_fail = 0
        for fp in sorted(dh0_dir.glob("*.npz")):
            tid = _parse_track_id(fp)
            if tid is None or tid in excluded:
                continue
            # NOTE(review): the track joins ALL before its file is read, so
            # an unreadable or skipped file is later counted as NEUTRAL.
            stage_all[evt].add(tid)
            try:
                with np.load(fp, allow_pickle=False) as Z:
                    if h_scale is None and "H_SCALE" in Z.files:
                        h_scale = float(Z["H_SCALE"])
                    x_rel = Z["X_rel"]
                    y_rel = Z["Y_rel"]

                    z3d = None
                    levels_file = None
                    z2d_wavg = None
                    if need_3d and "z_3d" in Z.files:
                        z3d = Z["z_3d"]
                        levels_file = np.asarray(Z["levels"], dtype=float)
                    if need_wavg and "z" in Z.files:
                        z2d_wavg = Z["z"]
                    if z3d is None and z2d_wavg is None:
                        # Nothing to classify in this file.
                        n_skip += 1
                        continue

                    awb, cwb = _classify_multilevel(
                        z3d,
                        levels_file,
                        x_rel,
                        y_rel,
                        classify_levels=cfg.classify_levels,
                        threshold=cfg.classify_threshold,
                        cfg=cfg.rwb_cfg,
                        z2d_wavg=z2d_wavg,
                    )
                if awb:
                    stage_awb[evt].add(tid)
                if cwb:
                    stage_cwb[evt].add(tid)
                n_ok += 1
            except Exception as exc:
                # Best-effort per file: report the problem and carry on.
                n_fail += 1
                print(
                    f"[warn] {evt}: failed to classify {fp.name}: {exc!r}",
                    flush=True,
                )
                continue

        print(
            f"[classify] {evt}: ok={n_ok} skip={n_skip} fail={n_fail}",
            flush=True,
        )

    for evt in cfg.stages:
        stage_neu[evt] = stage_all[evt] - (
            stage_awb.get(evt, set()) | stage_cwb.get(evt, set())
        )
        print(
            f"[classify] {evt}: ALL={len(stage_all[evt])} "
            f"AWB={len(stage_awb.get(evt, set()))} "
            f"CWB={len(stage_cwb.get(evt, set()))} "
            f"NEU={len(stage_neu[evt])}",
            flush=True,
        )

    return ClassifyResult(
        stage_all=stage_all,
        stage_awb=stage_awb,
        stage_cwb=stage_cwb,
        stage_neu=stage_neu,
        h_scale=h_scale,
        stages=list(cfg.stages),
        classify_levels=list(cfg.classify_levels),
        classify_threshold=cfg.classify_threshold,
    )