718 lines
18 KiB
Python
Raw Normal View History

2025-12-04 10:50:08 -05:00
import csv
import math
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Set
class TreatmentType(str, Enum):
    """Default-treatment rules recognised in the driver dictionary CSVs.

    Members subclass ``str``, so each compares equal to its raw CSV spelling
    (for example ``TreatmentType.UNK == "unk"``).
    """

    # Out-of-range or missing values are nulled and flagged via "<var>_unk".
    UNK = "unk"
    # Out-of-range values are nulled before the fallback treatments run.
    FIXED = "fixed"
@dataclass(frozen=True)
class VariableRule:
    """Normalised, immutable representation of a single driver-CSV row.

    Instances are built by ``_load_driver`` and consumed by ``_process_rule``.
    Field order matters: it defines the positional constructor signature.
    """

    # Variable name; also the key looked up in the raw input record.
    var_name: str
    # "float", "int" or "string" when loaded through _normalise_data_type.
    data_type: str
    # Inclusive valid range; None leaves that side unbounded.
    valid_min: Optional[float]
    valid_max: Optional[float]
    # Clamp bounds applied to numeric values in _process_rule.
    observed_cap_min_value: Optional[float]
    observed_cap_max_value: Optional[float]
    # Raw (uncast) replacement used when a value is still outside the valid
    # range after capping; cast with data_type at use time.
    observed_default_treatment_value: Optional[str]
    # Raw (uncast) replacement used when a value is missing.
    observed_missing_treatment_value: Optional[str]
    # UNK / FIXED special handling, or None for no special treatment.
    default_treatment_type: Optional[TreatmentType]
    # _load_driver only builds rules for predictor rows and always sets True.
    is_predictor: bool
# Directory the driver dictionary CSVs are read from (_load_driver joins file
# names onto it): the directory containing this module.
DRIVER_DIR = Path(__file__).parent
# Lazily populated cache of driver rules: model key ("model_a", ...) ->
# {var_name: VariableRule}. Filled by _ensure_driver_map.
DRIVER_MAP: Dict[str, Dict[str, VariableRule]] = {}
# Raw input variables consumed by each model. _ensure_driver_map keeps only
# driver rows whose var_name appears here, and each tuple (in this order) is
# the leading part of the matching MODEL_OUTPUT_FEATURE_ALLOWLIST entry.
MODEL_FEATURE_ALLOWLIST = {
    "model_a": (
        "AEPMAG05",
        "RET201",
        "PER201",
        "PER202",
        "PER222",
        "PER225",
        "PER235",
        "CTM18",
        "SC20S",
        "AT36SD",
        "FI36SD",
        "G250BD",
        "G250CD",
        "US36SD",
        "CV13",
        "CV25",
        "CV26",
        "AT01S",
        "AT104S",
        "FI02S",
        "FI20S",
        "FI35S",
        "G051S",
        "G205S",
        "G210S",
        "G218A",
        "G225S",
        "G230S",
        "G234S",
        "G300S",
        "IN02S",
        "IN12S",
        "OF20S",
        "RT20S",
        "INAP01",
        "G106S",
        "US02S",
        "US20S",
        "US24S",
        "US28S",
        "US32S",
        "US35S",
        "US36S",
        "SE20S",
        "US51A",
        "G205B",
        "INST_TRD",
        "RTL_TRD",
        "AGG402",
        "AGG403",
        "AGG423",
        "AGG424",
        "AGG903",
        "TRV03",
        "TRV04",
        "BALMAG01",
        "score_results",
    ),
    "model_b": (
        "UTLMAG01",
        "AEPMAG04",
        "PER201",
        "PER203",
        "PER222",
        "PER223",
        "PER224",
        "PER225",
        "PER235",
        "CTM23",
        "CT321",
        "CTC20",
        "CTA17",
        "CTA18",
        "SC21S",
        "SCC92",
        "SCBALM01",
        "AT36SD",
        "FI36SD",
        "RE36SD",
        "SE36SD",
        "US36SD",
        "LQA232YR",
        "LQR325YR",
        "RLE902",
        "CV25",
        "CV26",
        "RVDEXQ2",
        "AT01S",
        "AT104S",
        "AU20S",
        "BI21S",
        "BR33S",
        "CO06S",
        "FI02S",
        "FI03S",
        "FI20S",
        "FI32S",
        "FI33S",
        "FI34S",
        "FI35S",
        "FI101S",
        "FR21S",
        "FR32S",
        "G020S",
        "G102S",
        "G205S",
        "G210S",
        "G213A",
        "G225S",
        "G234S",
        "G301S",
        "G990S",
        "IN02S",
        "IN12S",
        "MT21S",
        "OF09S",
        "OF21S",
        "OF29S",
        "OF35S",
        "RE32S",
        "RT36S",
        "ST01S",
        "INAP01",
        "G106S",
        "S204S",
        "US02S",
        "US03S",
        "US12S",
        "US20S",
        "US24S",
        "US30S",
        "US34S",
        "SE20S",
        "SE21S",
        "SE34S",
        "SE36S",
        "JT20S",
        "JT33S",
        "JT70S",
        "G404S",
        "G405S",
        "G406S",
        "G407S",
        "G416S",
        "G417S",
        "US51A",
        "INST_TRD",
        "NOMT_TRD",
        "AGG512",
        "AGG516",
        "AGG902",
        "AGG903",
        "TRV03",
        "TRV04",
        "TRV06",
        "BALMAG01",
        "RVLR14",
        "PAYMNT06",
        "PAYMNT07",
        "score_results",
    ),
    "model_t": (
        "PDMAG01",
        "AEPMAG05",
        "AUT201",
        "PER201",
        "PER203",
        "PER204",
        "PER205",
        "PER223",
        "PER225",
        "PER253",
        "CTA17",
        "SE21CD",
        "RLE907",
        "CV26",
        "AT35B",
        "FI28S",
        "FI32S",
        "INAP01",
        "US01S",
        "US28S",
        "US34S",
        "US101S",
        "SE02S",
        "SE06S",
        "SE09S",
        "SE20S",
        "TRV06",
        "TRV10",
        "PAYMNT06",
    ),
}
# Output column order for each model: the model's raw input allowlist followed
# by the engineered "<var>_unk" indicator columns that model consumes.
MODEL_OUTPUT_FEATURE_ALLOWLIST = {
    "model_a_features": (
        *MODEL_FEATURE_ALLOWLIST["model_a"],
        "PER201_unk",
        "G225S_unk",
        "SC20S_unk",
        "RET201_unk",
        "US24S_unk",
    ),
    "model_b_features": MODEL_FEATURE_ALLOWLIST["model_b"],
    "model_t_features": (
        *MODEL_FEATURE_ALLOWLIST["model_t"],
        "AEPMAG05_unk",
        "PER201_unk",
    ),
}
def _to_float(candidate: Optional[str]) -> Optional[float]:
if candidate in (None, "", "null", "None"):
return None
try:
return float(candidate)
except (TypeError, ValueError):
return None
def _to_bool(candidate: Optional[str]) -> bool:
return str(candidate).strip() in {"1", "1.0", "true", "True"}
def _normalise_data_type(data_type: str) -> str:
if not data_type:
return "string"
data_type = data_type.strip().lower()
if data_type in {"float", "double"}:
return "float"
if data_type in {"int", "integer"}:
return "int"
return "string"
def _load_driver(name: str, allowed_features: Optional[Set[str]] = None) -> Dict[str, VariableRule]:
    """Load the predictor rules from one driver dictionary CSV.

    Args:
        name: File name of the CSV, resolved relative to ``DRIVER_DIR``.
        allowed_features: If given, only rows whose ``var_name`` is in this
            set are kept; ``None`` keeps every predictor row.

    Returns:
        Mapping of variable name to its ``VariableRule``. Rows whose
        ``is_predictor`` cell is not truthy (per ``_to_bool``) and rows with a
        blank ``var_name`` are skipped.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If a predictor row is found but the CSV has no ``var_name``
            column.
        ValueError: If a row names an unsupported ``default_treatment_type``;
            the message identifies the file and the variable.
    """
    path = DRIVER_DIR / name
    metadata: Dict[str, VariableRule] = {}
    # "utf-8-sig" drops a leading byte-order mark (common in spreadsheet
    # exports) that would otherwise be glued onto the first header name;
    # files without a BOM decode exactly as with plain "utf-8".
    with path.open(newline="", encoding="utf-8-sig") as csv_file:
        for row in csv.DictReader(csv_file):
            if not _to_bool(row.get("is_predictor")):
                continue
            # Stray whitespace would stop the name matching the allowlist and
            # the raw record keys, silently dropping the rule. Short rows give
            # None for missing cells, hence the `or ""`.
            var_name = (row["var_name"] or "").strip()
            if not var_name:
                continue
            if allowed_features is not None and var_name not in allowed_features:
                continue
            treatment_type_raw = (row.get("default_treatment_type") or "").strip().lower()
            treatment_type: Optional[TreatmentType] = None
            if treatment_type_raw:
                try:
                    treatment_type = TreatmentType(treatment_type_raw)
                except ValueError as err:
                    raise ValueError(
                        f"{path}: unsupported default_treatment_type "
                        f"{treatment_type_raw!r} for variable {var_name!r}"
                    ) from err
            metadata[var_name] = VariableRule(
                var_name=var_name,
                data_type=_normalise_data_type(row.get("data_type", "")),
                valid_min=_to_float(row.get("valid_min")),
                valid_max=_to_float(row.get("valid_max")),
                observed_cap_min_value=_to_float(row.get("observed_cap_min_value")),
                observed_cap_max_value=_to_float(row.get("observed_cap_max_value")),
                observed_default_treatment_value=row.get("observed_default_treatment_value") or None,
                observed_missing_treatment_value=row.get("observed_missing_treatment_value") or None,
                default_treatment_type=treatment_type,
                is_predictor=True,
            )
    return metadata
def _ensure_driver_map() -> None:
    """Populate ``DRIVER_MAP`` with the rules for models A, B and T on first use.

    Each model's driver CSV is filtered to that model's
    ``MODEL_FEATURE_ALLOWLIST`` entry. Only models not already present in
    ``DRIVER_MAP`` are loaded. Nothing is published until every pending
    driver has loaded successfully. If one file fails (missing file, bad
    treatment type), ``DRIVER_MAP`` is left unchanged and the exception
    propagates. The next call therefore retries, instead of finding a
    partially filled, truthy map and returning early.
    """
    driver_files = {
        "model_a": "data_dictionary_updated_A.csv",
        "model_b": "data_dictionary_updated_B.csv",
        "model_t": "data_dictionary_updated_T.csv",
    }
    pending = [model for model in driver_files if model not in DRIVER_MAP]
    if not pending:
        return
    loaded = {
        model: _load_driver(driver_files[model], set(MODEL_FEATURE_ALLOWLIST[model]))
        for model in pending
    }
    DRIVER_MAP.update(loaded)
def _is_number(value: Any) -> bool:
return isinstance(value, (int, float)) and not isinstance(value, bool)
def _is_missing(value: Any) -> bool:
if value is None:
return True
if isinstance(value, float) and math.isnan(value):
return True
if isinstance(value, str) and not value.strip():
return True
return False
def _safe_cast(value: Any, data_type: str) -> Any:
if value in (None, "", "null", "NULL"):
return None
try:
if data_type == "int":
if isinstance(value, str):
value = value.strip()
return int(float(value))
if data_type == "float":
return float(value)
return str(value)
except (TypeError, ValueError):
return None
def _value_out_of_range(value: Any, minimum: Optional[float], maximum: Optional[float]) -> bool:
    """Return True when *value* is numeric and lies outside [minimum, maximum].

    Both bounds are inclusive, and either may be ``None`` (unbounded on that
    side). Non-numeric values, as judged by ``_is_number``, are never out of
    range.
    """
    if _is_number(value):
        return (minimum is not None and value < minimum) or (
            maximum is not None and value > maximum
        )
    return False
def _process_rule(
    raw_record: Dict[str, Any],
    processed_record: Dict[str, Any],
    rule: VariableRule,
    lookup: Dict[str, VariableRule],
) -> None:
    """Apply one driver rule to *raw_record*, writing into *processed_record*.

    Mutates *processed_record* in place. It always sets ``rule.var_name``.
    For UNK-treated rules it also sets ``f"{rule.var_name}_unk"`` to 0 or 1.
    Rules whose own name ends in ``_unk`` are skipped entirely.

    Processing order (the order matters):
      1. Cast the raw value with ``rule.data_type``. Range and missingness
         checks for the treatments below use this cast value.
      2. UNK: an out-of-range or missing value sets the flag to 1 and nulls
         the value. Steps 4-5 are then skipped, so the value stays None.
         FIXED: an out-of-range value is nulled.
      3. Numeric values are clamped to the observed cap bounds.
      4. If still outside the valid range, replace the value with the cast
         ``observed_default_treatment_value`` (when set).
      5. If missing, replace the value with the cast
         ``observed_missing_treatment_value`` (when set).

    Args:
        raw_record: Raw input values keyed by variable name.
        processed_record: Output dict, updated in place.
        rule: The driver rule to apply.
        lookup: Not read by this function; NOTE(review): presumably kept for
            signature compatibility — confirm before removing.
    """
    if rule.var_name.endswith("_unk"):
        # Engineered flag is handled when the base variable is processed.
        return
    value = _safe_cast(raw_record.get(rule.var_name), rule.data_type)
    out_of_range = _value_out_of_range(value, rule.valid_min, rule.valid_max)
    null_due_to_unk_outlier = False
    is_missing_value = _is_missing(value)
    if rule.default_treatment_type == TreatmentType.UNK:
        unk_name = f"{rule.var_name}_unk"
        # NOTE(review): redundant; both branches below overwrite this flag.
        if unk_name not in processed_record:
            processed_record[unk_name] = 0
        if out_of_range or is_missing_value:
            processed_record[unk_name] = 1
            value = None
            # Suppresses the default/missing replacements further down.
            null_due_to_unk_outlier = True
        else:
            processed_record[unk_name] = 0
    elif rule.default_treatment_type == TreatmentType.FIXED and out_of_range:
        # NOTE(review): nulling here means the default-treatment branch below
        # (_value_out_of_range(None, ...) is False) can never apply
        # observed_default_treatment_value to FIXED rules. Out-of-range values
        # fall through to the missing-value treatment instead. Confirm intent.
        value = None
    if _is_number(value):
        # The cap bound is assigned as-is (typed Optional[float] on
        # VariableRule), so a capped int-typed value may come out as a float.
        if rule.observed_cap_min_value is not None and value < rule.observed_cap_min_value:
            value = rule.observed_cap_min_value
        if rule.observed_cap_max_value is not None and value > rule.observed_cap_max_value:
            value = rule.observed_cap_max_value
    if not null_due_to_unk_outlier and _value_out_of_range(value, rule.valid_min, rule.valid_max):
        default_value = rule.observed_default_treatment_value
        if default_value is not None:
            value = _safe_cast(default_value, rule.data_type)
    if not null_due_to_unk_outlier and _is_missing(value):
        missing_value = rule.observed_missing_treatment_value
        if missing_value is not None:
            value = _safe_cast(missing_value, rule.data_type)
    processed_record[rule.var_name] = value
def _process_model(
    record: Dict[str, Any],
    lookup: Dict[str, VariableRule],
    allowed_features: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
    """Apply every driver rule in *lookup* to *record* and return the features.

    Args:
        record: Raw input values keyed by variable name.
        lookup: Driver rules for one model, keyed by variable name.
        allowed_features: Optional output projection. When given, the result
            has exactly these keys in this order. Requested base features
            that were not produced map to ``None``. Requested engineered
            ``*_unk`` flags that no UNK rule produced default to ``0``.

    Returns:
        A new dict of processed feature values.
    """
    processed: Dict[str, Any] = {}
    for rule in lookup.values():
        _process_rule(record, processed, rule, lookup)
    # Ensure engineered _unk fields default to 0 when not explicitly set.
    for var_name, rule in lookup.items():
        if var_name.endswith("_unk") and rule.is_predictor:
            processed.setdefault(var_name, 0)
    if allowed_features is None:
        return processed
    # The lookup scan above only sees _unk names that have their own driver
    # rows. Flags requested only via allowed_features need the same default
    # here, otherwise they would surface as None.
    return {
        key: processed.get(key, 0 if key.endswith("_unk") else None)
        for key in allowed_features
    }
def __main__(
    AEPMAG04: int,
    AEPMAG05: int,
    AGG402: int,
    AGG403: int,
    AGG423: int,
    AGG424: int,
    AGG512: int,
    AGG516: int,
    AGG902: int,
    AGG903: int,
    AT01S: int,
    AT104S: int,
    AT35B: int,
    AT36SD: int,
    AU20S: int,
    AUT201: float,
    BALMAG01: int,
    BI21S: int,
    BR33S: int,
    CO06S: int,
    CT321: int,
    CTA17: int,
    CTA18: int,
    CTC20: int,
    CTM18: int,
    CTM23: int,
    CV13: float,
    CV25: float,
    CV26: float,
    FI02S: int,
    FI03S: int,
    FI101S: int,
    FI20S: int,
    FI28S: int,
    FI32S: int,
    FI33S: int,
    FI34S: int,
    FI35S: int,
    FI36SD: int,
    FR21S: int,
    FR32S: int,
    G020S: int,
    G051S: int,
    G102S: int,
    G106S: int,
    G205B: int,
    G205S: int,
    G210S: int,
    G213A: int,
    G218A: int,
    G225S: int,
    G230S: int,
    G234S: int,
    G250BD: int,
    G250CD: int,
    G300S: int,
    G301S: int,
    G404S: int,
    G405S: int,
    G406S: int,
    G407S: int,
    G416S: int,
    G417S: int,
    G990S: int,
    IN02S: int,
    IN12S: int,
    INAP01: int,
    INST_TRD: int,
    JT20S: int,
    JT33S: int,
    JT70S: int,
    LQA232YR: float,
    LQR325YR: float,
    MT21S: int,
    NOMT_TRD: int,
    OF09S: int,
    OF20S: int,
    OF21S: int,
    OF29S: int,
    OF35S: int,
    PAYMNT06: float,
    PAYMNT07: float,
    PDMAG01: int,
    PER201: float,
    PER202: float,
    PER203: float,
    PER204: float,
    PER205: float,
    PER222: float,
    PER223: float,
    PER224: float,
    PER225: float,
    PER235: int,
    PER253: int,
    RE32S: int,
    RE36SD: int,
    RET201: float,
    RLE902: int,
    RLE907: int,
    RT20S: int,
    RT36S: int,
    RTL_TRD: int,
    RVDEXQ2: int,
    RVLR14: str,
    S204S: int,
    SC20S: int,
    SC21S: int,
    SCBALM01: int,
    SCC92: int,
    score_results: int,
    SE02S: int,
    SE06S: int,
    SE09S: int,
    SE20S: int,
    SE21CD: int,
    SE21S: int,
    SE34S: int,
    SE36S: int,
    SE36SD: int,
    ST01S: int,
    TRV03: int,
    TRV04: int,
    TRV06: int,
    TRV10: int,
    US01S: int,
    US02S: int,
    US03S: int,
    US101S: int,
    US12S: int,
    US20S: int,
    US24S: int,
    US28S: int,
    US30S: float,
    US32S: int,
    US34S: int,
    US35S: int,
    US36S: int,
    US36SD: int,
    US51A: int,
    UTLMAG01: int,
) -> dict:
    """Preprocess one raw input record for models A, B and T.

    Each parameter is a raw input variable; together they form the record
    passed to ``_process_model`` for every model. Driver rules are loaded
    lazily via ``_ensure_driver_map``.

    Returns:
        ``{"results": [features]}`` where ``features`` maps
        ``"model_a_features"``, ``"model_b_features"`` and
        ``"model_t_features"`` (in that order) to the processed feature dict
        for that model, projected onto ``MODEL_OUTPUT_FEATURE_ALLOWLIST``.
    """
    # Snapshot the arguments as a name -> value mapping. This MUST stay the
    # first statement: at this point locals() holds exactly the parameters.
    record = dict(locals())
    _ensure_driver_map()
    features = {}
    for model in ("model_a", "model_b", "model_t"):
        output_key = f"{model}_features"
        features[output_key] = _process_model(
            record,
            DRIVER_MAP[model],
            MODEL_OUTPUT_FEATURE_ALLOWLIST[output_key],
        )
    return {"results": [features]}