"""Rule-driven pre-processing of raw input attributes into per-model features.

Each model (``model_a``, ``model_b``, ``model_t``) has a CSV driver dictionary
stored next to this file. For every predictor variable it gives the data type,
the valid range, the capping bounds and the treatment to apply to out-of-range
or missing values. ``__main__`` applies those rules to one input record and
returns the processed feature dictionaries for all three models.
"""

import csv
import math
from dataclasses import dataclass
from enum import Enum
from pathlib import Path
from typing import Any, Dict, Optional, Sequence, Set


class TreatmentType(str, Enum):
    """Supported treatment rules from the driver dictionaries.

    UNK:   out-of-range or missing values are nulled and a ``<var>_unk`` flag
           is set to 1 (0 otherwise).
    FIXED: out-of-range values are nulled (and then receive the missing-value
           treatment).
    """

    UNK = "unk"
    FIXED = "fixed"


@dataclass(frozen=True)
class VariableRule:
    """Normalised representation of a single driver row."""

    var_name: str
    # One of "float", "int" or "string" (see _normalise_data_type).
    data_type: str
    # Inclusive valid range; None means unbounded on that side.
    valid_min: Optional[float]
    valid_max: Optional[float]
    # Numeric values are clipped into these bounds when present.
    observed_cap_min_value: Optional[float]
    observed_cap_max_value: Optional[float]
    # Raw (uncast) replacement for values still out of range after capping.
    observed_default_treatment_value: Optional[str]
    # Raw (uncast) replacement for missing values.
    observed_missing_treatment_value: Optional[str]
    default_treatment_type: Optional[TreatmentType]
    is_predictor: bool


DRIVER_DIR = Path(__file__).parent
# Lazily populated cache: model name -> var_name -> rule.
DRIVER_MAP: Dict[str, Dict[str, VariableRule]] = {}

# Driver CSV file (relative to DRIVER_DIR) for each model.
_DRIVER_FILES: Dict[str, str] = {
    "model_a": "data_dictionary_updated_A.csv",
    "model_b": "data_dictionary_updated_B.csv",
    "model_t": "data_dictionary_updated_T.csv",
}

MODEL_FEATURE_ALLOWLIST = {
    "model_a": (
        "AEPMAG05", "RET201", "PER201", "PER202", "PER222", "PER225", "PER235",
        "CTM18", "SC20S", "AT36SD", "FI36SD", "G250BD", "G250CD", "US36SD",
        "CV13", "CV25", "CV26", "AT01S", "AT104S", "FI02S", "FI20S", "FI35S",
        "G051S", "G205S", "G210S", "G218A", "G225S", "G230S", "G234S", "G300S",
        "IN02S", "IN12S", "OF20S", "RT20S", "INAP01", "G106S", "US02S", "US20S",
        "US24S", "US28S", "US32S", "US35S", "US36S", "SE20S", "US51A", "G205B",
        "INST_TRD", "RTL_TRD", "AGG402", "AGG403", "AGG423", "AGG424", "AGG903",
        "TRV03", "TRV04", "BALMAG01", "score_results",
    ),
    "model_b": (
        "UTLMAG01", "AEPMAG04", "PER201", "PER203", "PER222", "PER223", "PER224",
        "PER225", "PER235", "CTM23", "CT321", "CTC20", "CTA17", "CTA18", "SC21S",
        "SCC92", "SCBALM01", "AT36SD", "FI36SD", "RE36SD", "SE36SD", "US36SD",
        "LQA232YR", "LQR325YR", "RLE902", "CV25", "CV26", "RVDEXQ2", "AT01S",
        "AT104S", "AU20S", "BI21S", "BR33S", "CO06S", "FI02S", "FI03S", "FI20S",
        "FI32S", "FI33S", "FI34S", "FI35S", "FI101S", "FR21S", "FR32S", "G020S",
        "G102S", "G205S", "G210S", "G213A", "G225S", "G234S", "G301S", "G990S",
        "IN02S", "IN12S", "MT21S", "OF09S", "OF21S", "OF29S", "OF35S", "RE32S",
        "RT36S", "ST01S", "INAP01", "G106S", "S204S", "US02S", "US03S", "US12S",
        "US20S", "US24S", "US30S", "US34S", "SE20S", "SE21S", "SE34S", "SE36S",
        "JT20S", "JT33S", "JT70S", "G404S", "G405S", "G406S", "G407S", "G416S",
        "G417S", "US51A", "INST_TRD", "NOMT_TRD", "AGG512", "AGG516", "AGG902",
        "AGG903", "TRV03", "TRV04", "TRV06", "BALMAG01", "RVLR14", "PAYMNT06",
        "PAYMNT07", "score_results",
    ),
    "model_t": (
        "PDMAG01", "AEPMAG05", "AUT201", "PER201", "PER203", "PER204", "PER205",
        "PER223", "PER225", "PER253", "CTA17", "SE21CD", "RLE907", "CV26",
        "AT35B", "FI28S", "FI32S", "INAP01", "US01S", "US28S", "US34S", "US101S",
        "SE02S", "SE06S", "SE09S", "SE20S", "TRV06", "TRV10", "PAYMNT06",
    ),
}

# Output feature order per model: the input allowlist plus engineered _unk flags.
MODEL_OUTPUT_FEATURE_ALLOWLIST = {
    "model_a_features": MODEL_FEATURE_ALLOWLIST["model_a"]
    + ("PER201_unk", "G225S_unk", "SC20S_unk", "RET201_unk", "US24S_unk"),
    "model_b_features": MODEL_FEATURE_ALLOWLIST["model_b"],
    "model_t_features": MODEL_FEATURE_ALLOWLIST["model_t"] + ("AEPMAG05_unk", "PER201_unk"),
}


def _to_float(candidate: Optional[str]) -> Optional[float]:
    """Parse a driver cell as float; empty/null-like or unparsable cells give None."""
    if candidate in (None, "", "null", "None"):
        return None
    try:
        return float(candidate)
    except (TypeError, ValueError):
        return None


def _to_bool(candidate: Optional[str]) -> bool:
    """Interpret a driver flag cell ("1", "1.0", "true" in any case) as a bool."""
    return str(candidate).strip().lower() in {"1", "1.0", "true"}


def _normalise_data_type(data_type: str) -> str:
    """Map a driver data-type label to "float", "int" or "string" (the fallback)."""
    if not data_type:
        return "string"
    data_type = data_type.strip().lower()
    if data_type in {"float", "double"}:
        return "float"
    if data_type in {"int", "integer"}:
        return "int"
    return "string"


def _load_driver(name: str, allowed_features: Optional[Set[str]] = None) -> Dict[str, VariableRule]:
    """Load predictor rules from the driver CSV ``DRIVER_DIR / name``.

    Only rows whose ``is_predictor`` flag is truthy are kept and, when
    ``allowed_features`` is given, only variables named in it.

    Raises:
        KeyError: if a predictor row comes from a CSV without a ``var_name`` column.
        ValueError: if a kept row names an unknown ``default_treatment_type``.
    """
    path = DRIVER_DIR / name
    metadata: Dict[str, VariableRule] = {}
    # utf-8-sig transparently drops a leading BOM (common in spreadsheet
    # exports); with plain utf-8 it would be glued onto the first header name.
    with path.open(newline="", encoding="utf-8-sig") as csv_file:
        for row in csv.DictReader(csv_file):
            if not _to_bool(row.get("is_predictor")):
                continue
            # DictReader yields None for cells missing from short rows.
            var_name = (row["var_name"] or "").strip()
            if not var_name:
                continue
            if allowed_features is not None and var_name not in allowed_features:
                continue

            treatment_raw = (row.get("default_treatment_type") or "").strip().lower()
            try:
                treatment_type = TreatmentType(treatment_raw) if treatment_raw else None
            except ValueError as err:
                raise ValueError(
                    f"{path}: unknown default_treatment_type {treatment_raw!r} for {var_name!r}"
                ) from err

            metadata[var_name] = VariableRule(
                var_name=var_name,
                data_type=_normalise_data_type(row.get("data_type", "")),
                valid_min=_to_float(row.get("valid_min")),
                valid_max=_to_float(row.get("valid_max")),
                observed_cap_min_value=_to_float(row.get("observed_cap_min_value")),
                observed_cap_max_value=_to_float(row.get("observed_cap_max_value")),
                observed_default_treatment_value=row.get("observed_default_treatment_value") or None,
                observed_missing_treatment_value=row.get("observed_missing_treatment_value") or None,
                default_treatment_type=treatment_type,
                is_predictor=True,
            )
    return metadata


def _ensure_driver_map() -> None:
    """Populate DRIVER_MAP once with every model's driver rules.

    All drivers are loaded into a local dict first and published together, so
    a failure while loading one driver cannot leave the cache half-filled
    (a non-empty cache is treated as complete on subsequent calls).
    """
    if DRIVER_MAP:
        return
    loaded = {
        model: _load_driver(filename, set(MODEL_FEATURE_ALLOWLIST[model]))
        for model, filename in _DRIVER_FILES.items()
    }
    DRIVER_MAP.update(loaded)


def _is_number(value: Any) -> bool:
    """True for int/float values, excluding bool."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _is_missing(value: Any) -> bool:
    """True for None, float NaN, or a blank string."""
    if value is None:
        return True
    if isinstance(value, float) and math.isnan(value):
        return True
    if isinstance(value, str) and not value.strip():
        return True
    return False


def _safe_cast(value: Any, data_type: str) -> Any:
    """Cast ``value`` to the normalised ``data_type``; return None if impossible.

    "int" goes through float so strings like "3.0" parse (fractions truncate).
    Infinite values cannot become ints and also yield None.
    """
    if value in (None, "", "null", "NULL"):
        return None
    try:
        if data_type == "int":
            if isinstance(value, str):
                value = value.strip()
            return int(float(value))
        if data_type == "float":
            return float(value)
        return str(value)
    except (TypeError, ValueError, OverflowError):
        return None


def _value_out_of_range(value: Any, minimum: Optional[float], maximum: Optional[float]) -> bool:
    """True if a numeric ``value`` lies outside [minimum, maximum]; non-numbers are never out of range."""
    if not _is_number(value):
        return False
    if minimum is not None and value < minimum:
        return True
    if maximum is not None and value > maximum:
        return True
    return False


def _process_rule(
    raw_record: Dict[str, Any],
    processed_record: Dict[str, Any],
    rule: VariableRule,
    lookup: Dict[str, VariableRule],
) -> None:
    """Apply one driver rule to ``raw_record`` and store the result in ``processed_record``.

    Steps, in order:
      1. cast the raw value to ``rule.data_type`` (uncastable -> None);
      2. UNK: out-of-range or missing values are nulled and ``<var>_unk`` is
         set to 1 (else 0); FIXED: out-of-range values are nulled;
      3. numeric values are clipped to the observed cap bounds;
      4. unless nulled by UNK, a value still outside the valid range is
         replaced by ``observed_default_treatment_value``;
      5. unless nulled by UNK, a missing value is replaced by
         ``observed_missing_treatment_value``.

    ``lookup`` is accepted for interface compatibility and is not used.
    """
    if rule.var_name.endswith("_unk"):
        # Engineered flag is written when its base variable is processed.
        return

    value = _safe_cast(raw_record.get(rule.var_name), rule.data_type)
    out_of_range = _value_out_of_range(value, rule.valid_min, rule.valid_max)
    null_due_to_unk_outlier = False

    if rule.default_treatment_type == TreatmentType.UNK:
        flagged = out_of_range or _is_missing(value)
        processed_record[f"{rule.var_name}_unk"] = 1 if flagged else 0
        if flagged:
            value = None
            null_due_to_unk_outlier = True
    elif rule.default_treatment_type == TreatmentType.FIXED and out_of_range:
        # NOTE(review): because the value is nulled here, an out-of-range FIXED
        # value skips step 4 and receives the *missing* treatment value rather
        # than observed_default_treatment_value. Confirm this is intended.
        value = None

    if _is_number(value):
        if rule.observed_cap_min_value is not None and value < rule.observed_cap_min_value:
            value = rule.observed_cap_min_value
        if rule.observed_cap_max_value is not None and value > rule.observed_cap_max_value:
            value = rule.observed_cap_max_value

    if not null_due_to_unk_outlier and _value_out_of_range(value, rule.valid_min, rule.valid_max):
        default_value = rule.observed_default_treatment_value
        if default_value is not None:
            value = _safe_cast(default_value, rule.data_type)

    if not null_due_to_unk_outlier and _is_missing(value):
        missing_value = rule.observed_missing_treatment_value
        if missing_value is not None:
            value = _safe_cast(missing_value, rule.data_type)

    processed_record[rule.var_name] = value


def _process_model(
    record: Dict[str, Any],
    lookup: Dict[str, VariableRule],
    allowed_features: Optional[Sequence[str]] = None,
) -> Dict[str, Any]:
    """Apply every rule in ``lookup`` to ``record``.

    When ``allowed_features`` is given, the result contains exactly those keys
    in that order. Requested features that were never produced are None,
    except engineered ``*_unk`` flags, which default to 0.
    """
    processed: Dict[str, Any] = {}
    for rule in lookup.values():
        _process_rule(record, processed, rule, lookup)

    # Ensure engineered _unk fields present in the driver default to 0.
    for var_name, rule in lookup.items():
        if var_name.endswith("_unk") and var_name not in processed and rule.is_predictor:
            processed[var_name] = 0

    if allowed_features is not None:
        # The drivers are filtered to the input allowlist, so _unk rows never
        # reach the loop above; default requested _unk flags to 0 here instead
        # of emitting None for them.
        processed = {
            key: processed.get(key, 0 if key.endswith("_unk") else None)
            for key in allowed_features
        }
    return processed


def __main__(
    AEPMAG04: int, AEPMAG05: int, AGG402: int, AGG403: int, AGG423: int, AGG424: int,
    AGG512: int, AGG516: int, AGG902: int, AGG903: int, AT01S: int, AT104S: int,
    AT35B: int, AT36SD: int, AU20S: int, AUT201: float, BALMAG01: int, BI21S: int,
    BR33S: int, CO06S: int, CT321: int, CTA17: int, CTA18: int, CTC20: int,
    CTM18: int, CTM23: int, CV13: float, CV25: float, CV26: float, FI02S: int,
    FI03S: int, FI101S: int, FI20S: int, FI28S: int, FI32S: int, FI33S: int,
    FI34S: int, FI35S: int, FI36SD: int, FR21S: int, FR32S: int, G020S: int,
    G051S: int, G102S: int, G106S: int, G205B: int, G205S: int, G210S: int,
    G213A: int, G218A: int, G225S: int, G230S: int, G234S: int, G250BD: int,
    G250CD: int, G300S: int, G301S: int, G404S: int, G405S: int, G406S: int,
    G407S: int, G416S: int, G417S: int, G990S: int, IN02S: int, IN12S: int,
    INAP01: int, INST_TRD: int, JT20S: int, JT33S: int, JT70S: int,
    LQA232YR: float, LQR325YR: float, MT21S: int, NOMT_TRD: int, OF09S: int,
    OF20S: int, OF21S: int, OF29S: int, OF35S: int, PAYMNT06: float,
    PAYMNT07: float, PDMAG01: int, PER201: float, PER202: float, PER203: float,
    PER204: float, PER205: float, PER222: float, PER223: float, PER224: float,
    PER225: float, PER235: int, PER253: int, RE32S: int, RE36SD: int,
    RET201: float, RLE902: int, RLE907: int, RT20S: int, RT36S: int,
    RTL_TRD: int, RVDEXQ2: int, RVLR14: str, S204S: int, SC20S: int, SC21S: int,
    SCBALM01: int, SCC92: int, score_results: int, SE02S: int, SE06S: int,
    SE09S: int, SE20S: int, SE21CD: int, SE21S: int, SE34S: int, SE36S: int,
    SE36SD: int, ST01S: int, TRV03: int, TRV04: int, TRV06: int, TRV10: int,
    US01S: int, US02S: int, US03S: int, US101S: int, US12S: int, US20S: int,
    US24S: int, US28S: int, US30S: float, US32S: int, US34S: int, US35S: int,
    US36S: int, US36SD: int, US51A: int, UTLMAG01: int,
) -> dict:
    """Process one input record for all three models.

    Each argument is the raw value of the attribute with the same name.

    Returns:
        ``{"results": [{"model_a_features": ..., "model_b_features": ...,
        "model_t_features": ...}]}`` where each inner value maps that model's
        output feature names (MODEL_OUTPUT_FEATURE_ALLOWLIST) to processed values.
    """
    # Must remain the first statement: at this point the local namespace holds
    # exactly the parameters, so this is the raw record keyed by variable name.
    record: Dict[str, Any] = dict(locals())
    _ensure_driver_map()
    return {
        "results": [
            {
                "model_a_features": _process_model(
                    record,
                    DRIVER_MAP["model_a"],
                    MODEL_OUTPUT_FEATURE_ALLOWLIST["model_a_features"],
                ),
                "model_b_features": _process_model(
                    record,
                    DRIVER_MAP["model_b"],
                    MODEL_OUTPUT_FEATURE_ALLOWLIST["model_b_features"],
                ),
                "model_t_features": _process_model(
                    record,
                    DRIVER_MAP["model_t"],
                    MODEL_OUTPUT_FEATURE_ALLOWLIST["model_t_features"],
                ),
            }
        ]
    }