"""NTD (National Transit Database) data loading and agency matching utilities.
This module provides functions to:
- Load and filter the NTD facility inventory to bus depot locations.
- Fuzzy-match a GTFS agency name and geographic centroid to the NTD agency
table, combining name similarity (rapidfuzz + IDF-weighted token coverage)
with geographic proximity.
"""
import re
from pathlib import Path
from typing import TypedDict
import geopandas as gpd
import numpy as np
import pandas as pd
from geopy.distance import geodesic
from rapidfuzz.fuzz import WRatio
# ---------------------------------------------------------------------------
# NTD facility inventory constants
# ---------------------------------------------------------------------------
# Facility types treated as bus depots, most preferred first. A type's index
# in this list becomes its ``depot_priority`` (0 = best) when facilities are
# loaded and ranked before distance matching.
_DEPOT_FACILITY_TYPES: list[str] = [
    "General Purpose Maintenance Facility/Depot",
    "Combined Administrative and Maintenance Facility (describe in Notes)",
    "Maintenance Facility (Service and Inspection)",
]

# "Primary Mode Served" codes that count as bus service. Demand response (DR)
# and vanpool (VP) are deliberately included: many small bus agencies report
# only under those modes.
_BUS_MODES: frozenset[str] = frozenset(("MB", "RB", "CB", "TB", "PB", "DR", "VP"))

# ---------------------------------------------------------------------------
# NTD agency table constants
# ---------------------------------------------------------------------------
# Columns of the NTD agency table that the matcher reads.
_NTD_ID_COL = "NTD_ID"
_OFFICIAL_NAME_COL = "Agency_Name"
_COMMON_NAME_COL = "Common_Name"

# One token = one maximal run of lowercase ASCII letters/digits; the pattern is
# applied to case-folded text, so everything else acts as a separator.
_TOKEN_RE = re.compile(r"[a-z0-9]+")

# --- Matching thresholds and weights ---------------------------------------
# Minimum blended name score (``_WRATIO_WEIGHT * WRatio + _IDF_WEIGHT * IDF
# coverage``) a candidate needs before it is distance-scored.
_NAME_SCORE_THRESHOLD = 20
# Largest allowed distance (km) between the best candidate's centroid and the
# query location.
_MAX_DISTANCE_KM = 200
# Blending weights for the two name-similarity signals. With _IDF_WEIGHT at
# 0.0 the IDF coverage term currently contributes nothing to the blend.
_WRATIO_WEIGHT = 0.5
_IDF_WEIGHT = 0.0
# Length scale (km) of the proximity score exp(-distance / scale).
_PROXIMITY_SCALE_KM = 75.0

# ---------------------------------------------------------------------------
# Path helpers
# ---------------------------------------------------------------------------
# File names of the bundled NTD data sets, resolved under ``ntd_path()``.
_FACILITIES_FILENAME = "2024 Facility Inventory_260428_1.xlsx"
_AGENCIES_FILENAME = "NTAD_National_Transit_Map_Agencies.csv"
def _ntd_facilities_path() -> Path:
    """Return the location of the bundled NTD facility-inventory workbook."""
    # Deferred import, presumably to avoid an import cycle with the
    # ``routee.transit`` package at module load time.
    from routee.transit import ntd_path as _data_root

    workbook = _data_root() / _FACILITIES_FILENAME
    return workbook
def _ntd_agencies_path() -> Path:
    """Return the location of the bundled NTAD agency CSV."""
    # Deferred import, presumably to avoid an import cycle with the
    # ``routee.transit`` package at module load time.
    from routee.transit import ntd_path as _data_root

    csv_file = _data_root() / _AGENCIES_FILENAME
    return csv_file
# ---------------------------------------------------------------------------
# Data loaders
# ---------------------------------------------------------------------------
def load_ntd_facilities(
    ntd_id: str | None = None,
    ntd_ids: list[str] | None = None,
) -> gpd.GeoDataFrame:
    """Load and filter the NTD facility inventory to bus depot locations.

    Reads the bundled NTD "Facility Inventory" xlsx and retains only rows that:

    1. Belong to a bus-operating agency (``Primary Mode Served`` in
       ``{MB, RB, CB, TB, PB, DR, VP}``).
    2. Are one of the three depot facility types (general purpose depot,
       combined admin/maintenance, or service-and-inspection facility).
    3. Have non-missing latitude/longitude values.

    Pass ``ntd_id`` to restrict to a single agency or ``ntd_ids`` for several.
    When both are omitted all bus depot facilities across all agencies are
    returned. Passing both is an error.

    Parameters
    ----------
    ntd_id : str | None
        NTD ID (e.g. ``"00001"``); shorter IDs are zero-padded to 5 digits
        before matching. Mutually exclusive with ``ntd_ids``.
    ntd_ids : list[str] | None
        NTD IDs, each zero-padded to 5 digits before matching. Facilities for
        all listed agencies are returned combined. A single string is treated
        as one ID (not iterated character by character). Mutually exclusive
        with ``ntd_id``.

    Returns
    -------
    gpd.GeoDataFrame
        Point GeoDataFrame in EPSG:4326 (built from ``Longitude``/``Latitude``)
        with the inventory columns, including ``NTD ID``, ``Agency Name``,
        ``Facility Type`` and ``Facility Name``, plus a ``depot_priority``
        column (0 = highest-priority depot type). The index is reset to 0..n-1.

    Raises
    ------
    ValueError
        If both ``ntd_id`` and ``ntd_ids`` are given, or if no facility
        survives the filters.
    """
    if ntd_id is not None and ntd_ids is not None:
        raise ValueError("Pass either ntd_id or ntd_ids, not both.")
    if ntd_id is not None:
        ntd_ids = [ntd_id]
    elif isinstance(ntd_ids, str):
        # A bare string would otherwise be iterated per character (matching
        # e.g. "00000"); treat it as a single ID instead.
        ntd_ids = [ntd_ids]

    df = pd.read_excel(_ntd_facilities_path(), dtype={"NTD ID": str})
    # Normalise NTD ID to zero-padded 5-digit string
    df["NTD ID"] = df["NTD ID"].str.zfill(5)

    # All filters are row-wise, so one combined mask selects exactly the rows
    # the step-by-step filters would, without intermediate copies.
    keep = (
        df["Primary Mode Served"].isin(_BUS_MODES)
        & df["Facility Type"].isin(_DEPOT_FACILITY_TYPES)
        & df["Latitude"].notna()
        & df["Longitude"].notna()
    )
    if ntd_ids is not None:
        wanted = {str(nid).zfill(5) for nid in ntd_ids}
        keep &= df["NTD ID"].isin(wanted)
    df = df.loc[keep].copy()

    if df.empty:
        id_desc = f" for NTD ID(s) {ntd_ids!r}" if ntd_ids is not None else ""
        raise ValueError(f"No bus depot facilities found in NTD inventory{id_desc}.")

    # Attach priority rank (lower = better): position in _DEPOT_FACILITY_TYPES.
    priority_map = {ft: i for i, ft in enumerate(_DEPOT_FACILITY_TYPES)}
    df["depot_priority"] = df["Facility Type"].map(priority_map)

    df = df.reset_index(drop=True)
    return gpd.GeoDataFrame(
        df,
        geometry=gpd.points_from_xy(df["Longitude"], df["Latitude"]),
        crs="EPSG:4326",
    )
def _load_ntd_agencies(bus_only: bool = False) -> pd.DataFrame:
    """Load the bundled NTD agency table.

    Parameters
    ----------
    bus_only:
        When ``True``, restrict to agencies that operate at least one bus mode
        according to the NTD facility inventory. When ``False`` (default), all
        agencies are returned — name + proximity scoring is sufficient to avoid
        rail-only mismatches in practice.

    Returns
    -------
    pd.DataFrame
        The agency CSV as read, with ``NTD_ID`` kept as text. When
        ``bus_only`` is ``True`` the surviving rows are re-indexed from 0.
    """
    agencies = pd.read_csv(_ntd_agencies_path(), dtype={_NTD_ID_COL: str})
    if not bus_only:
        return agencies

    # Derive the set of NTD IDs that operate bus modes from the facility xlsx.
    facilities = pd.read_excel(
        _ntd_facilities_path(),
        usecols=["NTD ID", "Primary Mode Served"],
        dtype={"NTD ID": str},
    )
    is_bus = facilities["Primary Mode Served"].isin(_BUS_MODES)
    # dropna: a missing facility ID must not "match" agencies with missing IDs.
    bus_ntd_ids: set[str] = set(
        facilities.loc[is_bus, "NTD ID"].dropna().str.zfill(5).unique()
    )
    # Facility IDs are zero-padded above; pad agency IDs the same way for the
    # comparison only, so an unpadded CSV value (e.g. "1") still matches
    # "00001" while the returned column keeps its original text.
    agency_keys = agencies[_NTD_ID_COL].str.zfill(5)
    return agencies[agency_keys.isin(bus_ntd_ids)].reset_index(drop=True)
# ---------------------------------------------------------------------------
# Fuzzy matching internals
# ---------------------------------------------------------------------------
def _tokenize_name(name: str) -> set[str]:
    """Split ``name`` into its distinct lowercase alphanumeric runs.

    The text is case-folded first; every maximal run matched by ``_TOKEN_RE``
    becomes one token, and all other characters act as separators.
    """
    folded = name.casefold()
    return {match.group(0) for match in _TOKEN_RE.finditer(folded)}
def _compute_token_idf(agencies: pd.DataFrame) -> dict[str, float]:
    """Compute smoothed IDF weights for the tokens in the agency names.

    Each agency row is one "document" whose tokens are the union of the tokens
    of its official (``Agency_Name``) and common (``Common_Name``) names. A
    token present in ``count`` of ``n`` documents is weighted
    ``log((n + 1) / (count + 1)) + 1``; tokens present in every document get
    exactly 1.0 and rarer tokens get more.

    Missing names (NaN) and missing columns are treated as empty strings —
    the same way :func:`match_agency_to_ntd` fills names before scoring —
    rather than contributing a spurious ``"nan"`` token.
    """

    def _column_text(column: str) -> list[str]:
        # Fill NaN before str() so missing names become "" rather than "nan".
        if column not in agencies.columns:
            return [""] * len(agencies)
        return agencies[column].fillna("").astype(str).tolist()

    token_document_counts: dict[str, int] = {}
    for official_name, common_name in zip(
        _column_text(_OFFICIAL_NAME_COL), _column_text(_COMMON_NAME_COL)
    ):
        # Union: a token counts once per agency even if it appears in both names.
        document_tokens = _tokenize_name(official_name) | _tokenize_name(common_name)
        for token in document_tokens:
            token_document_counts[token] = token_document_counts.get(token, 0) + 1

    n_documents = len(agencies)
    return {
        token: float(np.log((n_documents + 1) / (count + 1)) + 1.0)
        for token, count in token_document_counts.items()
    }
def _idf_query_coverage_score(
    query_tokens: set[str], candidate_name: str, token_idf: dict[str, float]
) -> float:
    """Return the IDF-weighted share (0-100) of query tokens found in a candidate.

    Every query token adds its IDF weight (1.0 if absent from ``token_idf``)
    to the total; only tokens also present in ``candidate_name`` add to the
    matched weight. Yields 0.0 when the query has no tokens, the candidate
    name has no tokens, or the total weight is zero.
    """
    if not query_tokens:
        return 0.0
    name_tokens = _tokenize_name(candidate_name)
    if not name_tokens:
        return 0.0

    def weight_of(token: str) -> float:
        return token_idf.get(token, 1.0)

    total_weight = sum(map(weight_of, query_tokens))
    if total_weight == 0:
        return 0.0
    matched_weight = sum(map(weight_of, query_tokens & name_tokens))
    return 100.0 * matched_weight / total_weight
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
class NTDAgencyMatch(TypedDict):
    """Row from the NTD agency table returned by :func:`match_agency_to_ntd`.

    Every key without a leading underscore is copied from the agency-table
    column of the same name. The two underscore-prefixed keys are diagnostics
    computed by the matcher.
    """

    OBJECTID: int
    NTD_ID: str
    NTAD_NTD_ID: str
    Agency_Name: str  # official legal agency name
    Common_Name: str  # common / everyday agency name
    City: str
    State: str
    Lon: float  # longitude of the agency point used for the distance check
    Lat: float  # latitude of the agency point used for the distance check
    In_Latest_NTAD_Upload: str
    License: str
    x: float  # NOTE(review): presumably export point coordinates — confirm
    y: float  # NOTE(review): presumably export point coordinates — confirm
    _name_score: int  # blended name-similarity score, truncated to int
    _distance_km: float  # geodesic distance from the query point, 2 decimals
def _blended_name_scores(agency_name: str, agencies: pd.DataFrame) -> np.ndarray:
    """Return the blended name-similarity score of ``agency_name`` for each row.

    Per row the score is ``_WRATIO_WEIGHT * wratio + _IDF_WEIGHT * idf``, where
    ``wratio`` is the higher rapidfuzz ``WRatio`` over the official and common
    names, and ``idf`` is the higher IDF-weighted query-token coverage over the
    same two names (each term is on a 0-100 scale).
    """
    official_names = agencies[_OFFICIAL_NAME_COL].fillna("")
    common_names = agencies[_COMMON_NAME_COL].fillna("")

    # WRatio handles partial/abbreviated matches robustly.
    wratio_scores = np.maximum(
        np.array([WRatio(agency_name, name) for name in official_names]),
        np.array([WRatio(agency_name, name) for name in common_names]),
    )
    name_scores = _WRATIO_WEIGHT * wratio_scores

    # IDF coverage down-weights generic words that appear in many agencies.
    # It is a full pass over the table, so skip it when its weight is zero:
    # the IDF scores are finite, so the blend is unchanged either way.
    if _IDF_WEIGHT:
        query_tokens = _tokenize_name(agency_name)
        token_idf = _compute_token_idf(agencies)
        idf_scores = np.maximum(
            np.array(
                [
                    _idf_query_coverage_score(query_tokens, name, token_idf)
                    for name in official_names
                ]
            ),
            np.array(
                [
                    _idf_query_coverage_score(query_tokens, name, token_idf)
                    for name in common_names
                ]
            ),
        )
        name_scores = name_scores + _IDF_WEIGHT * idf_scores
    return name_scores


def _row_to_ntd_match(
    row: pd.Series, name_score: float, distance_km: float
) -> NTDAgencyMatch:
    """Build an :class:`NTDAgencyMatch` from one agency row plus match diagnostics."""
    return NTDAgencyMatch(
        OBJECTID=int(row["OBJECTID"]),
        NTD_ID=str(row["NTD_ID"]),
        NTAD_NTD_ID=str(row["NTAD_NTD_ID"]),
        Agency_Name=str(row["Agency_Name"]),
        Common_Name=str(row["Common_Name"]),
        City=str(row["City"]),
        State=str(row["State"]),
        Lon=float(row["Lon"]),
        Lat=float(row["Lat"]),
        In_Latest_NTAD_Upload=str(row["In_Latest_NTAD_Upload"]),
        License=str(row["License"]),
        x=float(row["x"]),
        y=float(row["y"]),
        _name_score=int(name_score),
        _distance_km=round(distance_km, 2),
    )


def match_agency_to_ntd(
    agency_name: str,
    lat: float,
    lon: float,
    name_threshold: int = _NAME_SCORE_THRESHOLD,
    max_distance_km: float = _MAX_DISTANCE_KM,
    agency_id: str | None = None,
) -> NTDAgencyMatch:
    """Fuzzy-match an agency name and location to a row in the NTD agency table.

    Candidates are scored by a weighted combination of name similarity and
    geographic distance. The name score blends:

    1. Rapidfuzz ``WRatio`` (weight ``_WRATIO_WEIGHT``).
    2. IDF-weighted query-token coverage (weight ``_IDF_WEIGHT``), where tokens
       common across NTD agencies (e.g. "transit", "city") have less influence
       than rare tokens. With the current weight of 0.0 this term is skipped.

    Both the official legal name (``Agency_Name``) and common name
    (``Common_Name``) are considered; the higher of the two scores is used.

    Proximity is ``exp(-dist / _PROXIMITY_SCALE_KM)`` (about 0.88 at 10 km and
    0.26 at 100 km with the 75 km scale), so close candidates are strongly
    preferred. Agencies without ``Lat``/``Lon`` cannot be distance-scored and
    are skipped.

    If ``agency_id`` is provided and, after zero-padding to 5 digits, equals a
    candidate's (zero-padded) NTD ID, that candidate receives a +0.3 bonus to
    the combined score.

    Parameters
    ----------
    agency_name : str
        Agency name to match (e.g. from GTFS ``agency.txt`` or Mobility Database).
    lat : float
        Approximate latitude of the agency's service area (WGS84).
    lon : float
        Approximate longitude of the agency's service area (WGS84).
    name_threshold : int
        Minimum *blended* name score a candidate needs before distance scoring.
        The blended score is ``_WRATIO_WEIGHT * WRatio + _IDF_WEIGHT * IDF``;
        with the current weights (0.5 and 0.0) it ranges from 0 to 50, so the
        default of 20 admits candidates whose best ``WRatio`` is at least 40.
    max_distance_km : float
        If the best candidate's centroid is farther than this from ``(lat, lon)``,
        a ``ValueError`` is raised even if the name score is high.
    agency_id : str | None
        Optional GTFS ``agency_id``. When it parses as an integer whose
        zero-padded form equals a candidate's NTD ID, that candidate gets the
        bonus.

    Returns
    -------
    NTDAgencyMatch
        Row from the NTD agency table for the best match, plus ``_name_score``
        (blended name score truncated to int) and ``_distance_km`` (geodesic
        distance rounded to 2 decimals).

    Raises
    ------
    ValueError
        If the agency table is empty, if no candidate passes the name
        threshold, if none of the passing candidates has coordinates, or if
        the best candidate exceeds ``max_distance_km``.
    """
    agencies = _load_ntd_agencies()
    if agencies.empty:
        raise ValueError(
            "The bundled NTD agency table is empty; nothing to match against."
        )

    name_scores = _blended_name_scores(agency_name, agencies)

    # Pre-filter by name threshold to avoid unnecessary distance calculations
    passes_name = name_scores >= name_threshold
    if not passes_name.any():
        # argmax is positional, so look the name up positionally (iloc) too.
        best_idx = int(name_scores.argmax())
        best_name = agencies[_OFFICIAL_NAME_COL].iloc[best_idx]
        best_score = int(name_scores.max())
        raise ValueError(
            f"No NTD agency matched '{agency_name}' above the name threshold "
            f"of {name_threshold}. Best candidate was '{best_name}' "
            f"(score={best_score}). Try lowering name_threshold or check the "
            f"agency name spelling."
        )

    # A missing coordinate cannot be distance-scored (geopy rejects non-finite
    # points, and a NaN score would win argmax), so such rows are excluded.
    has_coords = (agencies["Lat"].notna() & agencies["Lon"].notna()).to_numpy()
    usable = passes_name & has_coords
    if not usable.any():
        raise ValueError(
            f"NTD agencies matched '{agency_name}' by name, but none of them has "
            f"Lat/Lon coordinates, so proximity cannot be checked."
        )
    candidates = agencies[usable]
    candidate_name_scores = name_scores[usable]

    # Compute great-circle distance from query point to each candidate centroid
    query_point = (lat, lon)
    distances_km = np.array(
        [
            geodesic(query_point, (cand_lat, cand_lon)).km
            for cand_lat, cand_lon in zip(candidates["Lat"], candidates["Lon"])
        ]
    )

    # Combined score: 0.5 * (name score / 100) + 0.2 * proximity. The name
    # score is already scaled by the blending weights, so with the current
    # weights WRatio effectively contributes 0.25 * WRatio / 100.
    proximity = np.exp(-distances_km / _PROXIMITY_SCALE_KM)
    combined = 0.5 * (candidate_name_scores / 100.0) + 0.2 * proximity

    # Agency ID bonus: if the GTFS agency_id zero-pads to an NTD ID, boost
    # that candidate so it wins when name + proximity are even close.
    if agency_id is not None:
        try:
            padded_id = str(int(agency_id)).zfill(5)
        except (ValueError, TypeError):
            padded_id = None
        if padded_id is not None:
            # Pad the table side as well, so unpadded CSV IDs still match.
            id_match_mask = (
                candidates[_NTD_ID_COL]
                .str.zfill(5)
                .eq(padded_id)
                .to_numpy(dtype=bool, na_value=False)
            )
            combined[id_match_mask] += 0.3

    best_pos = int(combined.argmax())
    best_distance_km = float(distances_km[best_pos])
    best_row = candidates.iloc[best_pos]
    if best_distance_km > max_distance_km:
        raise ValueError(
            f"Best NTD match for '{agency_name}' is '{best_row[_OFFICIAL_NAME_COL]}' "
            f"({best_row[_COMMON_NAME_COL]}), but its centroid is "
            f"{best_distance_km:.1f} km away — exceeds max_distance_km={max_distance_km}. "
            f"Verify the lat/lon or increase max_distance_km."
        )

    return _row_to_ntd_match(
        best_row, float(candidate_name_scores[best_pos]), best_distance_km
    )