Ankur Malik d94d65ed2d
All checks were successful
Build and Push Docker Image / test (push) Successful in 1m50s
Build and Push Docker Image / build_and_push (push) Successful in 3m7s
Add pd v3 pre processing block
2025-12-04 10:50:08 -05:00

718 lines
18 KiB
Python

import csv
import math
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Set
class TreatmentType(str, Enum):
    """Treatment applied to a variable's out-of-range values.

    Member values are the lower-cased strings accepted in the
    ``default_treatment_type`` column of the driver CSVs.
    """

    # _process_rule nulls missing/out-of-range values and emits a companion
    # ``<var>_unk`` flag (1 when nulled, 0 otherwise).
    UNK = "unk"
    # _process_rule nulls out-of-range values before capping and replacement.
    FIXED = "fixed"
@dataclass(frozen=True)
class VariableRule:
    """Immutable, normalised form of one predictor row from a driver CSV."""

    # Key of the raw input value in the request record.
    var_name: str
    # One of "int", "float" or "string" (see _normalise_data_type).
    data_type: str
    # Inclusive validity bounds; None leaves that side open.
    valid_min: Optional[float]
    valid_max: Optional[float]
    # Numeric values are clamped into these observed bounds when set.
    observed_cap_min_value: Optional[float]
    observed_cap_max_value: Optional[float]
    # Raw CSV text, cast with the rule's data_type only when it is applied.
    observed_default_treatment_value: Optional[str]
    observed_missing_treatment_value: Optional[str]
    # None when the driver row leaves the treatment column blank.
    default_treatment_type: Optional[TreatmentType]
    # _load_driver keeps only predictor rows, so loaded rules carry True.
    is_predictor: bool
# Driver CSVs are read from the directory that contains this module.
DRIVER_DIR: Path = Path(__file__).parent
# Model key -> {var_name -> VariableRule}; filled lazily by _ensure_driver_map().
DRIVER_MAP: Dict[str, Dict[str, VariableRule]] = {}
# Raw input variables consumed by each model.
# _ensure_driver_map passes each tuple (as a set) to _load_driver, which drops
# every driver row whose var_name is not listed here. The tuples are also the
# leading part of MODEL_OUTPUT_FEATURE_ALLOWLIST, so their order is the order of
# the keys in each model's processed feature dict — do not reorder casually.
MODEL_FEATURE_ALLOWLIST = {
    "model_a": (
        "AEPMAG05",
        "RET201",
        "PER201",
        "PER202",
        "PER222",
        "PER225",
        "PER235",
        "CTM18",
        "SC20S",
        "AT36SD",
        "FI36SD",
        "G250BD",
        "G250CD",
        "US36SD",
        "CV13",
        "CV25",
        "CV26",
        "AT01S",
        "AT104S",
        "FI02S",
        "FI20S",
        "FI35S",
        "G051S",
        "G205S",
        "G210S",
        "G218A",
        "G225S",
        "G230S",
        "G234S",
        "G300S",
        "IN02S",
        "IN12S",
        "OF20S",
        "RT20S",
        "INAP01",
        "G106S",
        "US02S",
        "US20S",
        "US24S",
        "US28S",
        "US32S",
        "US35S",
        "US36S",
        "SE20S",
        "US51A",
        "G205B",
        "INST_TRD",
        "RTL_TRD",
        "AGG402",
        "AGG403",
        "AGG423",
        "AGG424",
        "AGG903",
        "TRV03",
        "TRV04",
        "BALMAG01",
        "score_results",
    ),
    "model_b": (
        "UTLMAG01",
        "AEPMAG04",
        "PER201",
        "PER203",
        "PER222",
        "PER223",
        "PER224",
        "PER225",
        "PER235",
        "CTM23",
        "CT321",
        "CTC20",
        "CTA17",
        "CTA18",
        "SC21S",
        "SCC92",
        "SCBALM01",
        "AT36SD",
        "FI36SD",
        "RE36SD",
        "SE36SD",
        "US36SD",
        "LQA232YR",
        "LQR325YR",
        "RLE902",
        "CV25",
        "CV26",
        "RVDEXQ2",
        "AT01S",
        "AT104S",
        "AU20S",
        "BI21S",
        "BR33S",
        "CO06S",
        "FI02S",
        "FI03S",
        "FI20S",
        "FI32S",
        "FI33S",
        "FI34S",
        "FI35S",
        "FI101S",
        "FR21S",
        "FR32S",
        "G020S",
        "G102S",
        "G205S",
        "G210S",
        "G213A",
        "G225S",
        "G234S",
        "G301S",
        "G990S",
        "IN02S",
        "IN12S",
        "MT21S",
        "OF09S",
        "OF21S",
        "OF29S",
        "OF35S",
        "RE32S",
        "RT36S",
        "ST01S",
        "INAP01",
        "G106S",
        "S204S",
        "US02S",
        "US03S",
        "US12S",
        "US20S",
        "US24S",
        "US30S",
        "US34S",
        "SE20S",
        "SE21S",
        "SE34S",
        "SE36S",
        "JT20S",
        "JT33S",
        "JT70S",
        "G404S",
        "G405S",
        "G406S",
        "G407S",
        "G416S",
        "G417S",
        "US51A",
        "INST_TRD",
        "NOMT_TRD",
        "AGG512",
        "AGG516",
        "AGG902",
        "AGG903",
        "TRV03",
        "TRV04",
        "TRV06",
        "BALMAG01",
        "RVLR14",
        "PAYMNT06",
        "PAYMNT07",
        "score_results",
    ),
    # NOTE(review): unlike model_a/model_b, model_t does not list
    # "score_results" — confirm this is intentional.
    "model_t": (
        "PDMAG01",
        "AEPMAG05",
        "AUT201",
        "PER201",
        "PER203",
        "PER204",
        "PER205",
        "PER223",
        "PER225",
        "PER253",
        "CTA17",
        "SE21CD",
        "RLE907",
        "CV26",
        "AT35B",
        "FI28S",
        "FI32S",
        "INAP01",
        "US01S",
        "US28S",
        "US34S",
        "US101S",
        "SE02S",
        "SE06S",
        "SE09S",
        "SE20S",
        "TRV06",
        "TRV10",
        "PAYMNT06",
    ),
}
# Engineered ``<var>_unk`` indicator flags appended after each model's inputs.
# NOTE(review): _process_rule only writes a ``<var>_unk`` flag when that
# variable's driver rule has default_treatment_type "unk"; otherwise the
# projection in _process_model emits None for the flag — confirm the driver
# CSVs mark these variables accordingly.
_MODEL_A_UNK_FLAGS = ("PER201_unk", "G225S_unk", "SC20S_unk", "RET201_unk", "US24S_unk")
_MODEL_T_UNK_FLAGS = ("AEPMAG05_unk", "PER201_unk")

# Exact keys, in order, of each model's processed feature dict.
MODEL_OUTPUT_FEATURE_ALLOWLIST = {
    "model_a_features": MODEL_FEATURE_ALLOWLIST["model_a"] + _MODEL_A_UNK_FLAGS,
    "model_b_features": MODEL_FEATURE_ALLOWLIST["model_b"],
    "model_t_features": MODEL_FEATURE_ALLOWLIST["model_t"] + _MODEL_T_UNK_FLAGS,
}
def _to_float(candidate: Optional[str]) -> Optional[float]:
if candidate in (None, "", "null", "None"):
return None
try:
return float(candidate)
except (TypeError, ValueError):
return None
def _to_bool(candidate: Optional[str]) -> bool:
return str(candidate).strip() in {"1", "1.0", "true", "True"}
def _normalise_data_type(data_type: str) -> str:
if not data_type:
return "string"
data_type = data_type.strip().lower()
if data_type in {"float", "double"}:
return "float"
if data_type in {"int", "integer"}:
return "int"
return "string"
def _load_driver(name: str, allowed_features: Optional[Set[str]] = None) -> Dict[str, VariableRule]:
    """Load predictor rules from the driver CSV ``DRIVER_DIR / name``.

    Only rows whose ``is_predictor`` cell is truthy are kept. When
    ``allowed_features`` is given, a row is also dropped unless its
    ``var_name`` is in that set.

    Args:
        name: File name of the driver CSV, relative to ``DRIVER_DIR``.
        allowed_features: Optional whitelist of variable names to keep.

    Returns:
        Mapping of ``var_name`` to its VariableRule. A later row with the same
        ``var_name`` replaces an earlier one.

    Raises:
        OSError: If the file cannot be opened.
        KeyError: If a predictor row has no ``var_name`` column.
        ValueError: If ``default_treatment_type`` is not a TreatmentType value;
            the message names the file and the offending variable.
    """
    path = DRIVER_DIR / name
    metadata: Dict[str, VariableRule] = {}
    # "utf-8-sig" reads plain UTF-8 identically but also strips a leading BOM
    # (common in spreadsheet exports). Otherwise the BOM would be glued onto the
    # first header name and break the column lookups below.
    with path.open(newline="", encoding="utf-8-sig") as csv_file:
        for row in csv.DictReader(csv_file):
            if not _to_bool(row.get("is_predictor")):
                continue
            var_name = row["var_name"]
            if allowed_features is not None and var_name not in allowed_features:
                continue

            treatment_type_raw = (row.get("default_treatment_type") or "").strip().lower()
            treatment_type: Optional[TreatmentType] = None
            if treatment_type_raw:
                try:
                    treatment_type = TreatmentType(treatment_type_raw)
                except ValueError as exc:
                    raise ValueError(
                        f"{path}: variable {var_name!r} has unsupported "
                        f"default_treatment_type {treatment_type_raw!r}; expected one of "
                        f"{[member.value for member in TreatmentType]}"
                    ) from exc

            metadata[var_name] = VariableRule(
                var_name=var_name,
                data_type=_normalise_data_type(row.get("data_type", "")),
                valid_min=_to_float(row.get("valid_min")),
                valid_max=_to_float(row.get("valid_max")),
                observed_cap_min_value=_to_float(row.get("observed_cap_min_value")),
                observed_cap_max_value=_to_float(row.get("observed_cap_max_value")),
                # Empty cells become None so "no treatment configured" is explicit.
                observed_default_treatment_value=row.get("observed_default_treatment_value") or None,
                observed_missing_treatment_value=row.get("observed_missing_treatment_value") or None,
                default_treatment_type=treatment_type,
                is_predictor=True,
            )
    return metadata
def _ensure_driver_map() -> None:
    """Populate DRIVER_MAP from the per-model driver CSVs on first use.

    Each model's driver is filtered to that model's MODEL_FEATURE_ALLOWLIST
    entries. All drivers are loaded into a local dict first and published into
    DRIVER_MAP only after every load succeeds. If any load raises, DRIVER_MAP
    stays empty and the next call retries, instead of being left half-filled
    (which made later calls return early and fail with a KeyError on the
    missing model).
    """
    if DRIVER_MAP:
        return
    driver_files = (
        ("model_a", "data_dictionary_updated_A.csv"),
        ("model_b", "data_dictionary_updated_B.csv"),
        ("model_t", "data_dictionary_updated_T.csv"),
    )
    loaded = {
        model_key: _load_driver(file_name, set(MODEL_FEATURE_ALLOWLIST[model_key]))
        for model_key, file_name in driver_files
    }
    DRIVER_MAP.update(loaded)
def _is_number(value: Any) -> bool:
return isinstance(value, (int, float)) and not isinstance(value, bool)
def _is_missing(value: Any) -> bool:
if value is None:
return True
if isinstance(value, float) and math.isnan(value):
return True
if isinstance(value, str) and not value.strip():
return True
return False
def _safe_cast(value: Any, data_type: str) -> Any:
if value in (None, "", "null", "NULL"):
return None
try:
if data_type == "int":
if isinstance(value, str):
value = value.strip()
return int(float(value))
if data_type == "float":
return float(value)
return str(value)
except (TypeError, ValueError):
return None
def _value_out_of_range(value: Any, minimum: Optional[float], maximum: Optional[float]) -> bool:
    """Return True when a numeric *value* lies outside the inclusive [minimum, maximum].

    A None bound leaves that side open. Non-numeric values (None, strings,
    bools) are never reported as out of range.
    """
    if _is_number(value):
        too_low = minimum is not None and value < minimum
        return too_low or (maximum is not None and value > maximum)
    return False
def _process_rule(
    raw_record: Dict[str, Any],
    processed_record: Dict[str, Any],
    rule: VariableRule,
    lookup: Dict[str, VariableRule],
) -> None:
    """Apply one driver rule to ``raw_record`` and store the result in ``processed_record``.

    Processing order (statement order matters):
      1. Cast the raw value to ``rule.data_type`` via ``_safe_cast``.
      2. UNK rules: if the value is missing or outside [valid_min, valid_max],
         null it and set ``<var>_unk`` = 1; otherwise set ``<var>_unk`` = 0.
         FIXED rules: null the value if it is outside [valid_min, valid_max].
      3. Clamp numeric values to [observed_cap_min_value, observed_cap_max_value].
      4. Unless nulled in step 2 by an UNK rule: if the capped value is still
         outside the valid range and observed_default_treatment_value is set,
         replace it with that value (cast to the rule's data type).
      5. Unless nulled in step 2 by an UNK rule: if the value is missing and
         observed_missing_treatment_value is set, replace it with that value.

    Rules whose own name ends in ``_unk`` are skipped entirely.

    Args:
        raw_record: Raw input values keyed by variable name.
        processed_record: Output dict, mutated in place.
        rule: The driver rule to apply.
        lookup: Full rule map for the model. It is not used by this function.

    NOTE(review): for FIXED rules an out-of-range value is set to None in
    step 2. Step 4 only acts on numbers, so it never fires for that value,
    which then receives observed_missing_treatment_value in step 5 rather than
    observed_default_treatment_value. Confirm this is intended.
    """
    if rule.var_name.endswith("_unk"):
        # Engineered flag is handled when the base variable is processed.
        return
    value = _safe_cast(raw_record.get(rule.var_name), rule.data_type)
    out_of_range = _value_out_of_range(value, rule.valid_min, rule.valid_max)
    # Set when an UNK rule nulled the value; such values skip steps 4 and 5.
    null_due_to_unk_outlier = False
    is_missing_value = _is_missing(value)
    if rule.default_treatment_type == TreatmentType.UNK:
        unk_name = f"{rule.var_name}_unk"
        # Redundant: both branches below always overwrite the flag.
        if unk_name not in processed_record:
            processed_record[unk_name] = 0
        if out_of_range or is_missing_value:
            processed_record[unk_name] = 1
            value = None
            null_due_to_unk_outlier = True
        else:
            processed_record[unk_name] = 0
    elif rule.default_treatment_type == TreatmentType.FIXED and out_of_range:
        value = None
    # Capping runs before the default-treatment check, so an out-of-range value
    # may be clamped back into the valid range here and never reach step 4.
    if _is_number(value):
        if rule.observed_cap_min_value is not None and value < rule.observed_cap_min_value:
            value = rule.observed_cap_min_value
        if rule.observed_cap_max_value is not None and value > rule.observed_cap_max_value:
            value = rule.observed_cap_max_value
    if not null_due_to_unk_outlier and _value_out_of_range(value, rule.valid_min, rule.valid_max):
        default_value = rule.observed_default_treatment_value
        if default_value is not None:
            value = _safe_cast(default_value, rule.data_type)
    if not null_due_to_unk_outlier and _is_missing(value):
        missing_value = rule.observed_missing_treatment_value
        if missing_value is not None:
            value = _safe_cast(missing_value, rule.data_type)
    processed_record[rule.var_name] = value
def _process_model(
    record: Dict[str, Any],
    lookup: Dict[str, VariableRule],
    allowed_features: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
    """Apply every rule in *lookup* to *record* and return the processed features.

    Args:
        record: Raw input values keyed by variable name.
        lookup: One model's driver rules, keyed by variable name.
        allowed_features: When given, the result is projected onto exactly
            these keys, in this order. Keys that were never produced map to None.

    Returns:
        Processed feature values, including any ``<var>_unk`` flags written
        while applying the rules.
    """
    cleaned: Dict[str, Any] = {}
    for rule in lookup.values():
        _process_rule(record, cleaned, rule, lookup)

    # Engineered ``_unk`` predictors listed in the driver but not written by
    # any rule above default to 0.
    unk_predictors = (
        name for name, rule in lookup.items() if name.endswith("_unk") and rule.is_predictor
    )
    for name in unk_predictors:
        cleaned.setdefault(name, 0)

    if allowed_features is None:
        return cleaned
    return {feature: cleaned.get(feature) for feature in allowed_features}
def __main__(
    AEPMAG04: int,
    AEPMAG05: int,
    AGG402: int,
    AGG403: int,
    AGG423: int,
    AGG424: int,
    AGG512: int,
    AGG516: int,
    AGG902: int,
    AGG903: int,
    AT01S: int,
    AT104S: int,
    AT35B: int,
    AT36SD: int,
    AU20S: int,
    AUT201: float,
    BALMAG01: int,
    BI21S: int,
    BR33S: int,
    CO06S: int,
    CT321: int,
    CTA17: int,
    CTA18: int,
    CTC20: int,
    CTM18: int,
    CTM23: int,
    CV13: float,
    CV25: float,
    CV26: float,
    FI02S: int,
    FI03S: int,
    FI101S: int,
    FI20S: int,
    FI28S: int,
    FI32S: int,
    FI33S: int,
    FI34S: int,
    FI35S: int,
    FI36SD: int,
    FR21S: int,
    FR32S: int,
    G020S: int,
    G051S: int,
    G102S: int,
    G106S: int,
    G205B: int,
    G205S: int,
    G210S: int,
    G213A: int,
    G218A: int,
    G225S: int,
    G230S: int,
    G234S: int,
    G250BD: int,
    G250CD: int,
    G300S: int,
    G301S: int,
    G404S: int,
    G405S: int,
    G406S: int,
    G407S: int,
    G416S: int,
    G417S: int,
    G990S: int,
    IN02S: int,
    IN12S: int,
    INAP01: int,
    INST_TRD: int,
    JT20S: int,
    JT33S: int,
    JT70S: int,
    LQA232YR: float,
    LQR325YR: float,
    MT21S: int,
    NOMT_TRD: int,
    OF09S: int,
    OF20S: int,
    OF21S: int,
    OF29S: int,
    OF35S: int,
    PAYMNT06: float,
    PAYMNT07: float,
    PDMAG01: int,
    PER201: float,
    PER202: float,
    PER203: float,
    PER204: float,
    PER205: float,
    PER222: float,
    PER223: float,
    PER224: float,
    PER225: float,
    PER235: int,
    PER253: int,
    RE32S: int,
    RE36SD: int,
    RET201: float,
    RLE902: int,
    RLE907: int,
    RT20S: int,
    RT36S: int,
    RTL_TRD: int,
    RVDEXQ2: int,
    RVLR14: str,
    S204S: int,
    SC20S: int,
    SC21S: int,
    SCBALM01: int,
    SCC92: int,
    score_results: int,
    SE02S: int,
    SE06S: int,
    SE09S: int,
    SE20S: int,
    SE21CD: int,
    SE21S: int,
    SE34S: int,
    SE36S: int,
    SE36SD: int,
    ST01S: int,
    TRV03: int,
    TRV04: int,
    TRV06: int,
    TRV10: int,
    US01S: int,
    US02S: int,
    US03S: int,
    US101S: int,
    US12S: int,
    US20S: int,
    US24S: int,
    US28S: int,
    US30S: float,
    US32S: int,
    US34S: int,
    US35S: int,
    US36S: int,
    US36SD: int,
    US51A: int,
    UTLMAG01: int,
) -> dict:
    """Pre-process one raw attribute record for models A, B and T.

    Every parameter is a raw input value keyed by its own name. Each model's
    driver rules, loaded lazily from the driver CSVs, are applied to the same
    record.

    Returns:
        ``{"results": [{"model_a_features": {...}, "model_b_features": {...},
        "model_t_features": {...}}]}``. Each inner dict holds that model's
        processed features in MODEL_OUTPUT_FEATURE_ALLOWLIST order.
    """
    # Must remain the first statement. At this point the local namespace holds
    # exactly the parameters above, keyed by name, in declaration order.
    record = dict(locals())

    _ensure_driver_map()
    output_to_model = (
        ("model_a_features", "model_a"),
        ("model_b_features", "model_b"),
        ("model_t_features", "model_t"),
    )
    features: Dict[str, Any] = {}
    for output_key, model_key in output_to_model:
        features[output_key] = _process_model(
            record,
            DRIVER_MAP[model_key],
            MODEL_OUTPUT_FEATURE_ALLOWLIST[output_key],
        )
    return {"results": [features]}