Source code for demos.models.rebalancing

# fmt: off
import orca
import numpy as np
import pandas as pd
from templates.utils.models import columns_in_formula
from templates import estimated_models, modelmanager as mm
from templates.utils import transition
from templates.utils.transition import GrowthRateTransition
from config import DEMOSConfig, HHRebalancingModuleConfig, SimultaneousCalibrationConfig, get_config

import time
from logging_logic import log_execution_time
from loguru import logger

# Name under which `household_rebalancing` is registered as an orca step
# (used by the `@orca.step(STEP_NAME)` decorator below).
STEP_NAME = "household_rebalancing"

def _record_marital_counts(year, married_col, divorced_col):
    """Append one row of marital-status counts to the ``marital_rebalanced`` table.

    Parameters
    ----------
    year : int
        Value written to the ``year`` column of the appended row.
    married_col : str
        Column that receives the number of persons with ``MAR == 1``.
    divorced_col : str
        Column that receives the number of persons with ``MAR == 3``.
    """
    # Look the persons table up at call time so the snapshot reflects any rows
    # added or removed earlier in the step.
    mar = orca.get_table("persons").local.MAR
    marital_rebalanced = orca.get_table("marital_rebalanced")
    snapshot = pd.DataFrame(
        [[year, (mar == 1).sum(), (mar == 3).sum()]],
        columns=["year", married_col, divorced_col],
    )
    marital_rebalanced.local = pd.concat([marital_rebalanced.local, snapshot])


@orca.step(STEP_NAME)
def household_rebalancing(households, persons, year, get_new_households, get_new_person_id, rebalanced_households, rebalanced_persons):
    """
    Adjust household counts to match control totals by geography and control category.

    Compares the current number of households in each (``geoid_col``,
    ``control_col``) cell with the control table's target for ``year``, and
    duplicates or removes randomly sampled households so that each cell reaches
    its target. Members of duplicated households are copied with new person
    IDs; members of removed households are removed with them. Removed records
    are appended to ``rebalanced_households`` / ``rebalanced_persons``.

    Parameters
    ----------
    households : orca.Table
        The households table containing household-level attributes.
    persons : orca.Table
        The persons table containing individual-level attributes.
    year : int
        The current simulation year.
    get_new_households : callable
        Function to generate new unique household IDs.
    get_new_person_id : callable
        Function to generate new unique person IDs.
    rebalanced_households : orca.Table
        Table for storing removed household records.
    rebalanced_persons : orca.Table
        Table for storing removed person records.

    Returns
    -------
    None

    Notes
    -----
    - Modifies the households and persons tables in place by adding/removing records.
    - The control table name, geography column and control column come from
      ``hh_rebalancing_module_config``. The control table must be indexed by
      ``year`` and have exactly three columns: geography, control and a value
      (target) column.
    - Every (geography, control) cell that currently has households must appear
      in the control rows for ``year``; otherwise the ``.loc`` lookup raises
      ``KeyError``. Control cells with no current households are ignored
      (there is nothing to duplicate from).
    - Appends a marital-count snapshot (``married_original`` /
      ``divorced_original``) to ``marital_rebalanced`` before processing and
      another (``married_after`` / ``divorced_after``) after processing.
    - Returns early, after the first snapshot, if the control table has no
      rows for ``year``.
    - Households are sampled with replacement only when a cell needs more
      duplicates than it has households. Sampling uses the global
      ``np.random`` state.
    """
    start_time = time.time()

    # Load calibration config
    demos_config: DEMOSConfig = get_config()
    module_config: HHRebalancingModuleConfig = demos_config.hh_rebalancing_module_config

    # Snapshot marital counts before any household is added or removed.
    _record_marital_counts(year, "married_original", "divorced_original")

    CONTROL_TABLE = module_config.control_table
    GEOID_COL = module_config.geoid_col
    CONTROL_COL = module_config.control_col
    key_cols = [GEOID_COL, CONTROL_COL]

    control_table_wrapped = orca.get_table(CONTROL_TABLE)
    assert GEOID_COL in control_table_wrapped.local_columns, f"{GEOID_COL} must be in {CONTROL_TABLE}"
    assert CONTROL_COL in control_table_wrapped.local_columns, f"{CONTROL_COL} must be in {CONTROL_TABLE}"
    assert control_table_wrapped.index.name == "year", f"The index of {CONTROL_TABLE} must be 'year'"
    assert len(control_table_wrapped.local_columns) == 3, (
        f"{CONTROL_TABLE} needs to have exactly 3 columns: {GEOID_COL}, {CONTROL_COL} and the value column"
    )
    assert persons.household_id.nunique() == households.index.nunique(), (
        f"`persons` and `households` tables do not have coherent sizes. "
        f"{persons.household_id.nunique()} vs. {households.index.nunique()}"
    )

    if year not in control_table_wrapped.local.index:
        return

    # The one column that is neither the geography nor the control key holds the target.
    value_column = [c for c in control_table_wrapped.local_columns if c not in key_cols][0]

    # Current household counts per (geography, control) cell, plus the
    # positional row numbers of each cell's households within index_df.
    index_df = households.to_frame(key_cols).sort_values(key_cols)
    grouped = index_df.groupby(key_cols)
    indices = grouped.indices
    current_count = grouped.size()

    # `.loc[[year]]` (a list) always yields a DataFrame. A scalar `.loc[year]`
    # returns a Series when the year has a single control row, and the
    # dict-based `astype` / `set_index` calls below would then fail.
    year_targets = (
        control_table_wrapped.local.loc[[year]]
        .astype({GEOID_COL: "str", CONTROL_COL: "str"})
        .set_index(key_cols)[value_column]
    )
    # Target minus current. A positive value means households to duplicate;
    # a negative value means households to remove.
    hh_difference = year_targets.loc[current_count.index] - current_count

    to_remove_hh = []
    to_duplicate_hh = []
    for (geo_id, control_value), adjustment in hh_difference.items():
        valid_indices = index_df.index[indices[(geo_id, control_value)]]
        # np.random.choice needs an integer size; targets may be stored as floats.
        n_draws = int(round(abs(adjustment)))
        # Replacement is needed only when duplicating more households than the
        # cell contains. Removals draw without replacement so no household is
        # removed twice.
        selected_hh = np.random.choice(
            valid_indices,
            size=n_draws,
            replace=(adjustment > 0) and (n_draws > len(valid_indices)),
        ).tolist()
        if adjustment < 0:
            to_remove_hh += selected_hh
        elif adjustment > 0:
            to_duplicate_hh += selected_hh

    logger.debug(f"Number of households to duplicate: {len(to_duplicate_hh)}")
    logger.debug(f"Number of households to remove: {len(to_remove_hh)}")

    # Duplicate the households accordingly.
    # We duplicate first to reduce the chances of a household_id collision.
    to_duplicate_hh.sort()
    # NOTE: per the original author, get_new_households also creates the new
    # household rows; they are overwritten with the copied data below.
    new_hh_ids = get_new_households(len(to_duplicate_hh))
    new_hh_rows = households.local.loc[to_duplicate_hh].copy()
    new_hh_rows.index = new_hh_ids

    # A household sampled k times appears k times in hh_mapping, so the merge
    # copies its members k times, once per new household id.
    hh_mapping = pd.DataFrame({"orig_hh": to_duplicate_hh, "new_hh": new_hh_ids})
    new_person_rows = persons.local[persons.household_id.isin(to_duplicate_hh)].copy()
    new_person_rows = new_person_rows.merge(hh_mapping, left_on="household_id", right_on="orig_hh")
    new_person_rows["household_id"] = new_person_rows["new_hh"]
    new_person_rows.drop(["orig_hh", "new_hh"], inplace=True, axis=1)
    new_person_rows.index = get_new_person_id(len(new_person_rows))

    households.local.loc[new_hh_ids] = new_hh_rows
    persons.local = pd.concat([persons.local, new_person_rows])

    # Remove the households accordingly, archiving them (and their members)
    # in the rebalanced_* tables first.
    to_remove_hh.sort()
    rebalanced_households.local = pd.concat([rebalanced_households.local, households.local.loc[to_remove_hh]])
    rebalanced_persons.local = pd.concat(
        [rebalanced_persons.local, persons.local[persons.household_id.isin(to_remove_hh)]]
    )
    persons.local = persons.local[~persons.household_id.isin(to_remove_hh)]
    households.local = households.local[~households.index.isin(to_remove_hh)]

    log_execution_time(start_time, orca.get_injectable("year"), "rebalancing")

    # Snapshot marital counts after rebalancing.
    _record_marital_counts(year, "married_after", "divorced_after")