"""
Geospatial plotting utilities for geoPFA.
This module provides both 2D and 3D plotting functions for geospatial
GeoDataFrames used throughout the geoPFA workflow.
"""
from __future__ import annotations
import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import griddata
import contextily as ctx
from mpl_toolkits.mplot3d import Axes3D
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
from matplotlib.gridspec import GridSpec
from matplotlib.ticker import MaxNLocator
import shapely
from shapely.ops import linemerge
from contextlib import suppress
[docs]
class GeospatialDataPlotters:
    """Collection of static plotting utilities for geoPFA geospatial data.

    Every method is a ``@staticmethod``; the class only acts as a namespace.
    The public entry points (``geo_plot``, ``plot_zoom_in``, ``raster_plot``
    and ``geo_plot_3d``) build a matplotlib figure, display it with
    ``plt.show()`` and return ``None``.
    """

    # ------------------------------------------------------------------
    # shared helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _has_column(col):
        """Return True when ``col`` names a data column to colour by.

        ``None`` and the string ``"none"`` (any capitalisation) both mean
        "plot the geometry only".
        """
        return col is not None and str(col).lower() != "none"

    @staticmethod
    def _default_axis_labels(crs):
        """Return default ``(xlabel, ylabel)`` names for data in ``crs``.

        The axes are matched by their ``direction`` ("east" -> x,
        "north" -> y) rather than by position, because the position of the
        easting/longitude axis in ``crs.axis_info`` differs between CRSs.
        When the directions cannot be resolved, the labels fall back to
        ``axis_info[1]`` for x and ``axis_info[0]`` for y; without a CRS
        they are ``"X-axis"`` and ``"Y-axis"``.
        """
        if not crs:
            return "X-axis", "Y-axis"
        axes = list(crs.axis_info)
        by_direction = {
            str(getattr(axis, "direction", "")).lower(): axis.name
            for axis in axes
        }
        if "east" in by_direction and "north" in by_direction:
            return by_direction["east"], by_direction["north"]
        if len(axes) >= 2:  # noqa: PLR2004
            return axes[1].name, axes[0].name
        return "X-axis", "Y-axis"

    @staticmethod
    def _coords3_from_point(pt):
        """Return ``(x, y, z)`` for a point, with ``z = 0.0`` for 2D points."""
        z = None
        # Reading ``.z`` on a 2D shapely point raises DimensionError, which
        # ``getattr(..., default)`` does not swallow, so check ``has_z``
        # first. Objects without ``has_z`` are still probed for ``.z``.
        if getattr(pt, "has_z", True):
            z = getattr(pt, "z", None)
        if z is None:
            c0 = pt.coords[0]
            z = c0[2] if len(c0) >= 3 else 0.0  # noqa: PLR2004
        return (pt.x, pt.y, z)

    @staticmethod
    def _first_coord3(geom):
        """Return the first vertex of ``geom`` as ``(x, y, z)``.

        Polygons use the first vertex of their exterior ring and multi-part
        geometries use their first part. A missing z is reported as ``0.0``.
        """
        gtype = geom.geom_type
        if gtype.startswith("Multi") or gtype == "GeometryCollection":
            return GeospatialDataPlotters._first_coord3(geom.geoms[0])
        if gtype == "Polygon":
            first = geom.exterior.coords[0]
        else:
            first = geom.coords[0]
        z = first[2] if len(first) >= 3 else 0.0  # noqa: PLR2004
        return (first[0], first[1], z)

    @staticmethod
    def _iter_polygons(geom):
        """Yield the Polygon parts of ``geom`` (nothing for other types)."""
        if geom.geom_type == "Polygon":
            yield geom
        elif geom.geom_type == "MultiPolygon":
            yield from geom.geoms

    @staticmethod
    def _iter_rings(geom):
        """Yield exterior then interior rings of every polygon in ``geom``."""
        for polygon in GeospatialDataPlotters._iter_polygons(geom):
            yield polygon.exterior
            yield from polygon.interiors

    @staticmethod
    def _max_z(geoms):
        """Return the largest z found in point/polygon ``geoms``, else 0."""
        zs = []
        for geom in geoms:
            if geom.geom_type == "Point":
                zs.append(GeospatialDataPlotters._coords3_from_point(geom)[2])
            else:
                for ring in GeospatialDataPlotters._iter_rings(geom):
                    zs.extend(
                        c[2]
                        for c in ring.coords
                        if len(c) >= 3  # noqa: PLR2004
                    )
        return max(zs, default=0)

    @staticmethod
    def _slice_mask(arr, x_slice, y_slice, z_slice):
        """Return a boolean mask of the rows of ``arr`` inside the slices.

        ``arr`` is an ``(n, 3)`` array of x, y, z. A row is kept when each
        coordinate is ``<=`` its slice value; a slice of ``None`` is ignored.
        """
        mask = np.ones(len(arr), dtype=bool)
        for axis, bound in enumerate((x_slice, y_slice, z_slice)):
            if bound is not None:
                mask &= arr[:, axis] <= bound
        return mask

    @staticmethod
    def _align_values(values, n):
        """Return ``values`` as a float array of exactly ``n`` entries.

        Extra values are dropped from the end; missing ones are padded with
        NaN so each value stays paired with the point at the same index.
        """
        vals = np.asarray(values, dtype=float).ravel()
        if len(vals) >= n:
            return vals[:n]
        return np.concatenate([vals, np.full(n - len(vals), np.nan)])

    @staticmethod
    def _line_coords3(geom):
        """Return the vertices of a (Multi)LineString as an ``(n, 3)`` array.

        2D lines get a zero z column. Returns ``None`` when there are no
        vertices.
        """
        parts = (
            geom.geoms if isinstance(geom, shapely.MultiLineString) else [geom]
        )
        arrs = []
        for ls in parts:
            arr = np.asarray(ls.coords, dtype=float)
            if arr.ndim != 2 or len(arr) == 0:  # noqa: PLR2004
                continue  # empty line: nothing to add
            if arr.shape[1] == 2:  # noqa: PLR2004
                arr = np.c_[arr, np.zeros(len(arr))]
            arrs.append(arr[:, :3])
        return np.vstack(arrs) if arrs else None

    @staticmethod
    def _build_well_pts(_well):
        """Convert a well path into an ``(n, 3)`` float array of x, y, z.

        Parameters
        ----------
        _well : object or None
            Either an object with a ``geometry`` attribute (e.g. a
            GeoDataFrame) holding points or lines, or a plain shapely
            ``LineString`` / ``MultiLineString``.

        Returns
        -------
        numpy.ndarray or None
            The well vertices, or ``None`` when ``_well`` is ``None``, empty,
            or of an unsupported type.
        """
        plotters = GeospatialDataPlotters
        if _well is None:
            return None
        if hasattr(_well, "geometry"):
            geoms = list(_well.geometry)
            if not geoms:
                return None
            if all(g.geom_type == "Point" for g in geoms):
                return np.array(
                    [plotters._coords3_from_point(p) for p in geoms],
                    dtype=float,
                )
            # fallback: lines in a GDF
            line_types = (shapely.LineString, shapely.MultiLineString)
            merged = None
            if len(geoms) > 1:
                # Merging is a deliberate best effort; on failure every line
                # geometry is converted on its own below.
                with suppress(Exception):
                    merged = linemerge(geoms)
            if isinstance(merged, line_types):
                return plotters._line_coords3(merged)
            arrs = []
            for geom in geoms:
                if isinstance(geom, line_types):
                    arr = plotters._line_coords3(geom)
                    if arr is not None:
                        arrs.append(arr)
            return np.vstack(arrs) if arrs else None
        # plain shapely lines
        if isinstance(_well, (shapely.LineString, shapely.MultiLineString)):
            return plotters._line_coords3(_well)
        return None

    # ------------------------------------------------------------------
    # 2D plots
    # ------------------------------------------------------------------
    @staticmethod
    def geo_plot(  # noqa: PLR0913, PLR0917
        gdf,
        col,
        units,
        title,
        area_outline=None,
        overlay=None,
        xlabel="default",
        ylabel="default",
        cmap="jet",
        xlim=None,
        ylim=None,
        extent=None,
        basemap=False,
        markersize=15,
        figsize=(10, 10),
        vmin=None,
        vmax=None,
    ):
        """Plot data using ``gdf.plot()``.

        Preserves geometry, but does not look smooth.

        Parameters
        ----------
        gdf : geopandas.GeoDataFrame
            Geodataframe containing data to plot, including a geometry column
            and crs.
        col : str or None
            Name of column containing data value to plot. ``None`` or
            ``"none"`` plots the geometry without colouring or colorbar.
        units : str
            Units of data to plot; used as the colorbar label.
        title : str
            Title to add to plot.
        area_outline : geopandas.GeoDataFrame
            Optional, geodataframe containing outline of area to overlay on
            plot.
        overlay : geopandas.GeoDataFrame
            Optional, geodataframe containing data locations to plot over map
            data.
        xlabel, ylabel : str
            Optional, label for x-axis and y-axis. ``"default"`` derives the
            label from the crs of ``gdf``.
        cmap : str
            Optional, colormap to use instead of the default 'jet'.
        xlim, ylim : tuple
            Optional, limits to use for x and y axes.
        extent : list
            List of length 4 containing the extent (i.e., bounding box) to use
            in lieu of xlim and ylim, in this order:
            [x_min, y_min, x_max, y_max]. An explicit ``xlim`` / ``ylim``
            takes priority over ``extent`` for its own axis.
        basemap : bool
            Option to add a basemap, defaults to False.
        markersize : int
            Option to specify marker size to use in plot. Defaults to 15.
        figsize : tuple
            Option to specify figure size. Defaults to (10,10).
        vmin, vmax : float
            Optional minimum and maximum values to include in colorbar. Each
            one that is not provided falls back to the min / max value of the
            data in the column to plot.
        """
        plotters = GeospatialDataPlotters
        fig, ax = plt.subplots(figsize=figsize)
        if not plotters._has_column(col):
            gdf.plot(ax=ax)
        else:
            # Resolve each bound on its own so a lone vmin or vmax is honoured.
            norm = plt.Normalize(
                vmin=gdf[col].min() if vmin is None else vmin,
                vmax=gdf[col].max() if vmax is None else vmax,
            )
            gdf.plot(
                ax=ax,
                marker="s",
                markersize=markersize,
                column=col,
                cmap=cmap,
                norm=norm,
                legend=False,
            )
            sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
            cbar = fig.colorbar(sm, ax=ax)
            cbar.set_label(units)
        if area_outline is not None:
            area_outline.boundary.plot(ax=ax, color="black")
        if overlay is not None:
            overlay.plot(ax=ax, color="gray", markersize=3, alpha=0.5)

        default_x, default_y = plotters._default_axis_labels(gdf.crs)
        if xlabel == "default":
            xlabel = default_x
        if ylabel == "default":
            ylabel = default_y

        # extent only fills in the axes that have no explicit limit
        if extent is not None:
            if xlim is None:
                xlim = (extent[0], extent[2])
            if ylim is None:
                ylim = (extent[1], extent[3])
        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None:
            ax.set_ylim(ylim)

        # Add the basemap after the limits are final so tiles are requested
        # for the displayed window, and hand over the data crs so the tiles
        # are warped to it.
        if basemap:
            if gdf.crs:
                ctx.add_basemap(ax, crs=gdf.crs.to_string())
            else:
                ctx.add_basemap(ax)

        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True)
        fig.tight_layout()
        plt.show()

    @staticmethod
    def plot_zoom_in(  # noqa: PLR0913, PLR0917
        gdf,
        col,
        units,
        title,
        xlim,
        ylim,
        figsize,
        markersize,
        xlabel,
        ylabel,
        cmap,
    ):
        """Plot a zoomed-in version of a geoPFA map.

        ``xlim`` and ``ylim`` determine the displayed extent. The data layer
        is drawn semi-transparent (``alpha=0.25``). No basemap is added at
        present; see the TODO in the body.

        Parameters
        ----------
        gdf : geopandas.GeoDataFrame
            Data to plot, including a geometry column and crs.
        col : str or None
            Column to colour by. ``None`` or ``"none"`` plots the geometry
            without colouring or colorbar.
        units : str
            Colorbar label.
        title : str
            Plot title.
        xlim, ylim : tuple or None
            Axis limits; ``None`` leaves the axis auto-scaled.
        figsize : tuple
            Figure size.
        markersize : int
            Marker size.
        xlabel, ylabel : str
            Axis labels. ``"default"`` derives the label from the crs of
            ``gdf``.
        cmap : str
            Colormap name.
        """
        plotters = GeospatialDataPlotters
        fig, ax = plt.subplots(figsize=figsize)
        if not plotters._has_column(col):
            gdf.plot(ax=ax)
        else:
            norm = plt.Normalize(vmin=gdf[col].min(), vmax=gdf[col].max())
            gdf.plot(
                ax=ax,
                marker="s",
                markersize=markersize,
                column=col,
                cmap=cmap,
                norm=norm,
                legend=False,
                alpha=0.25,
            )
            sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
            cbar = fig.colorbar(sm, ax=ax)
            cbar.set_label(units)

        default_x, default_y = plotters._default_axis_labels(gdf.crs)
        if xlabel == "default":
            xlabel = default_x
        if ylabel == "default":
            ylabel = default_y

        # TODO: Basemap is causing problems. Fix at a later date.
        # (ctx.add_basemap is intentionally not called here.)
        if xlim is not None:
            ax.set_xlim(xlim)
        if ylim is not None:
            ax.set_ylim(ylim)
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.grid(True)
        fig.tight_layout()
        plt.show()

    @staticmethod
    def raster_plot(gdf, col, units, layer):
        """Plot point data as an interpolated heatmap using ``pcolormesh``.

        Creates a smoother plot than ``geo_plot``, but does not preserve
        geometry in the plot: values are linearly interpolated onto a
        500 x 500 grid spanning the bounding box of the points.

        Parameters
        ----------
        gdf : geopandas.GeoDataFrame
            Point data (``geometry.x`` / ``geometry.y`` are read).
        col : str
            Column holding the values to interpolate.
        units : str
            Colorbar label.
        layer : str
            Layer name used in the title (``"<layer>: heatmap"``).
        """
        x = gdf.geometry.x
        y = gdf.geometry.y
        z = gdf[col]
        # regular grid over the data's bounding box
        grid_x, grid_y = np.meshgrid(
            np.linspace(x.min(), x.max(), 500),
            np.linspace(y.min(), y.max(), 500),
        )
        # cells outside the convex hull of the points stay NaN
        grid_z = griddata((x, y), z, (grid_x, grid_y), method="linear")
        fig, ax = plt.subplots(figsize=(10, 10))
        mesh = ax.pcolormesh(
            grid_x, grid_y, grid_z, shading="auto", cmap="jet"
        )
        fig.colorbar(mesh, ax=ax, label=units)
        ax.set_title(f"{layer}: heatmap")
        ax.set_xlabel("easting (m)")
        ax.set_ylabel("northing (m)")
        ax.grid(True)
        fig.tight_layout()
        plt.show()

    # ------------------------------------------------------------------
    # 3D plot
    # ------------------------------------------------------------------
    @staticmethod
    def geo_plot_3d(  # noqa: PLR0912, PLR0913, PLR0914, PLR0915, PLR0917
        gdf,
        col,
        units,
        title,
        area_outline=None,
        overlay=None,
        well_path=None,
        well_path_values=None,
        xlabel="default",
        ylabel="default",
        zlabel="Z-axis",
        cmap="jet",
        xlim=None,
        ylim=None,
        zlim=None,
        extent=None,
        markersize=15,
        figsize=(12, 10),
        vmin=None,
        vmax=None,
        filter_threshold=None,
        x_slice=None,
        y_slice=None,
        z_slice=None,
        # Well-path colorbar settings
        well_units="Temperature (°C)",
        well_cmap="magma",
        show_well_colorbar=True,
        well_vmin=None,  # independent well-path vmin
        well_vmax=None,  # independent well-path vmax
        # Main (favorability) colorbar settings
        show_main_colorbar=True,
        # View controls: azimuths approximate "looking toward" each compass direction
        view_nw=(20, 135),  # from SE looking toward NW
        view_ne=(20, 45),  # from SW looking toward NE
        view_sw=(20, -135),  # from NE looking toward SW
        view_se=(20, -45),  # from NW looking toward SE
        # Layout control: relative width of colorbar column
        cbar_width=0.08,
    ):
        """Plot 3D geospatial data in four directional views (2x2 grid).

        The panels look toward NW, NE, SW and SE. The two colorbars live in
        a separate right-hand column:

        - Top: main dataset colorbar (``col`` / ``units``)
        - Bottom: well-path colorbar (``well_units``), if well values exist

        Parameters
        ----------
        gdf : geopandas.GeoDataFrame
            Data to plot. The geometry type of the first row decides whether
            the frame is drawn as points or as polygons.
        col : str or None
            Column to colour points by. ``None`` or ``"none"`` draws points in
            blue without a main colorbar. Polygons are always drawn in light
            blue.
        units : str
            Main colorbar label.
        title : str
            Prefix of each panel title.
        area_outline : geopandas.GeoDataFrame
            Optional polygon outline, drawn one unit above the highest z of
            the (sliced) data.
        overlay : geopandas.GeoDataFrame
            Optional data locations, drawn as small gray markers.
        well_path : GeoDataFrame, LineString or MultiLineString
            Optional well trajectory (see ``_build_well_pts``).
        well_path_values : array-like
            Optional values along the well path, paired with the well points
            by index. Too-short input is padded with NaN, too-long input is
            truncated.
        xlabel, ylabel : str
            Axis labels; ``"default"`` derives them from the crs of ``gdf``.
        zlabel : str
            z-axis label; an empty value falls back to ``"Z-axis"``.
        cmap : str
            Main colormap name.
        xlim, ylim, zlim : tuple
            Optional axis limits. They are used unless ``extent`` is given
            and ``zlim`` is ``None``.
        extent : list
            Optional list of length 6,
            [x_min, y_min, z_min, x_max, y_max, z_max]. Only applied when
            ``zlim`` is ``None``.
        markersize : int
            Marker size for the main points (wells use 1.6x this).
        figsize : tuple
            Figure size.
        vmin, vmax : float
            Optional main colour limits; default to the min / max of ``col``
            in the unsliced ``gdf``.
        filter_threshold : float
            Optional; rows with ``col`` below it are dropped. Ignored when no
            column is given.
        x_slice, y_slice, z_slice : float
            Optional; only geometries whose first vertex is ``<=`` the value
            are kept. The same slices are applied to the well points.
        well_units : str
            Well colorbar label.
        well_cmap : str
            Well colormap name.
        show_well_colorbar : bool
            Whether to draw the well colorbar.
        well_vmin, well_vmax : float
            Optional well colour limits; default to the finite min / max of
            the well values.
        show_main_colorbar : bool
            Whether to draw the main colorbar.
        view_nw, view_ne, view_sw, view_se : tuple
            ``(elev, azim)`` passed to ``view_init`` for each panel.
        cbar_width : float
            Width of the colorbar column relative to one panel column.
        """
        plotters = GeospatialDataPlotters
        has_col = plotters._has_column(col)

        # ---------- prep main dataset ----------
        # colour scale comes from the full input so slicing / filtering does
        # not shift the colours
        if has_col:
            norm_main = plt.Normalize(
                vmin=gdf[col].min() if vmin is None else vmin,
                vmax=gdf[col].max() if vmax is None else vmax,
            )
            cmap_main_obj = plt.get_cmap(cmap)
        else:
            norm_main = None
            cmap_main_obj = None

        # slicing on the first coordinate of each geometry
        gdf_sliced = gdf
        slicing = any(b is not None for b in (x_slice, y_slice, z_slice))
        if slicing and not gdf.empty:
            first_xyz = np.array(
                [plotters._first_coord3(g) for g in gdf.geometry],
                dtype=float,
            )
            gdf_sliced = gdf.loc[
                plotters._slice_mask(first_xyz, x_slice, y_slice, z_slice)
            ]

        # threshold filter (needs a real column to compare against)
        if filter_threshold is not None and has_col:
            gdf_filtered = gdf_sliced[gdf_sliced[col] >= filter_threshold]
        else:
            gdf_filtered = gdf_sliced
        if gdf_filtered.empty and well_path is None:
            print("No data to plot after filtering and slicing.")
            return

        # colour array for points (only created if col provided)
        if has_col and not gdf_filtered.empty:
            filtered_colors = cmap_main_obj(
                norm_main(gdf_filtered[col].to_numpy())
            )
        else:
            filtered_colors = "blue"

        # main geometry coordinates, computed once for all four panels
        point_xyz = None
        polygon_verts = []
        if not gdf_filtered.empty:
            gtype0 = gdf_filtered.geometry.iloc[0].geom_type
            if gtype0 == "Point":
                point_xyz = np.array(
                    [plotters._first_coord3(g) for g in gdf_filtered.geometry],
                    dtype=float,
                )
            elif gtype0 in {"Polygon", "MultiPolygon"}:
                polygon_verts = [
                    [
                        (c[0], c[1], c[2] if len(c) >= 3 else 0)  # noqa: PLR2004
                        for c in ring.coords
                    ]
                    for geom in gdf_filtered.geometry
                    for ring in plotters._iter_rings(geom)
                ]
        polygon_xyz = (
            np.array([v for ring in polygon_verts for v in ring], dtype=float)
            if polygon_verts
            else None
        )

        # overlay coordinates
        overlay_xyz = None
        if (
            overlay is not None
            and hasattr(overlay, "empty")
            and not overlay.empty
        ):
            overlay_xyz = np.array(
                [plotters._first_coord3(g) for g in overlay.geometry],
                dtype=float,
            )

        # well points + values: align values to the points BEFORE slicing so
        # the same mask keeps each value attached to its own point
        well_pts = plotters._build_well_pts(well_path)
        well_vals = None
        if well_pts is not None:
            if well_path_values is not None:
                well_vals = plotters._align_values(
                    well_path_values, len(well_pts)
                )
            keep = plotters._slice_mask(well_pts, x_slice, y_slice, z_slice)
            well_pts = well_pts[keep]
            if well_vals is not None:
                well_vals = well_vals[keep]
        has_well_pts = well_pts is not None and len(well_pts) > 0
        # do we actually have usable well values?
        has_well_values = bool(
            has_well_pts
            and well_vals is not None
            and np.isfinite(well_vals).any()
        )
        if has_well_values:
            w_cmap = plt.get_cmap(well_cmap)
            norm_w = plt.Normalize(
                vmin=np.nanmin(well_vals) if well_vmin is None else well_vmin,
                vmax=np.nanmax(well_vals) if well_vmax is None else well_vmax,
            )

        # area outline, lifted one unit above the highest data z
        outline_paths = []
        if (
            area_outline is not None
            and hasattr(area_outline, "empty")
            and not area_outline.empty
        ):
            z_top = plotters._max_z(gdf_sliced.geometry) + 1
            for geom in area_outline.geometry:
                # geometries without an exterior ring are skipped
                for polygon in plotters._iter_polygons(geom):
                    xs, ys = zip(
                        *[(c[0], c[1]) for c in polygon.exterior.coords]
                    )
                    outline_paths.append((xs, ys, [z_top] * len(xs)))

        # labels
        default_x, default_y = plotters._default_axis_labels(gdf.crs)
        xlabel_final = default_x if xlabel == "default" else xlabel
        ylabel_final = default_y if ylabel == "default" else ylabel
        zlabel_final = zlabel or "Z-axis"

        show_main = bool(
            show_main_colorbar and has_col and not gdf_filtered.empty
        )
        show_well = bool(show_well_colorbar and has_well_values)

        # ---------- figure layout: 2x2 views + right colorbar column ----------
        # Grid: 2 rows x 3 columns
        # [ NW | NE | main_cbar ]
        # [ SW | SE | well_cbar ]
        fig = plt.figure(figsize=figsize, constrained_layout=True)
        gs = GridSpec(2, 3, figure=fig, width_ratios=[1, 1, cbar_width])
        ax_nw = fig.add_subplot(gs[0, 0], projection="3d")
        ax_ne = fig.add_subplot(gs[0, 1], projection="3d")
        ax_sw = fig.add_subplot(gs[1, 0], projection="3d")
        ax_se = fig.add_subplot(gs[1, 1], projection="3d")
        cax_main = fig.add_subplot(gs[0, 2])  # main colorbar
        cax_well = fig.add_subplot(gs[1, 2])  # well colorbar
        # make cbar axes frameless but keep ticks/labels visible
        for cax in (cax_main, cax_well):
            for spine in cax.spines.values():
                spine.set_visible(False)
        # hide a colorbar axis entirely when it will stay unused
        if not show_main:
            cax_main.set_axis_off()
        if not show_well:
            cax_well.set_axis_off()
        # set view angles
        ax_nw.view_init(*view_nw)
        ax_ne.view_init(*view_ne)
        ax_sw.view_init(*view_sw)
        ax_se.view_init(*view_se)

        # ---------- shared per-panel plotting ----------
        def _plot_on(ax):
            """Draw all layers on ``ax``.

            Returns the value-coloured well scatter, or ``None`` when the
            well has no usable values.
            """
            # main geometries
            if point_xyz is not None:
                if isinstance(filtered_colors, str):
                    ax.scatter(
                        point_xyz[:, 0],
                        point_xyz[:, 1],
                        point_xyz[:, 2],
                        s=markersize,
                        color=filtered_colors,
                    )
                else:
                    ax.scatter(
                        point_xyz[:, 0],
                        point_xyz[:, 1],
                        point_xyz[:, 2],
                        s=markersize,
                        c=filtered_colors,
                    )
            elif polygon_verts:
                # an artist belongs to one axes, so build one per panel
                ax.add_collection3d(
                    Poly3DCollection(
                        polygon_verts,
                        alpha=0.5,
                        edgecolor="grey",
                        facecolor="lightblue",
                    )
                )
                # older matplotlib does not grow the limits for added
                # collections, so do it explicitly
                ax.auto_scale_xyz(
                    polygon_xyz[:, 0], polygon_xyz[:, 1], polygon_xyz[:, 2]
                )
            # overlay
            if overlay_xyz is not None:
                ax.scatter(
                    overlay_xyz[:, 0],
                    overlay_xyz[:, 1],
                    overlay_xyz[:, 2],
                    color="gray",
                    s=5,
                    alpha=0.5,
                )
            # well path scatter
            sc_well = None
            if has_well_pts:
                if has_well_values:
                    sc_well = ax.scatter(
                        well_pts[:, 0],
                        well_pts[:, 1],
                        well_pts[:, 2],
                        s=markersize * 1.6,
                        c=well_vals,
                        cmap=w_cmap,
                        norm=norm_w,
                        alpha=0.9,
                        zorder=5,
                    )
                else:
                    # just draw the well in black if no values
                    ax.scatter(
                        well_pts[:, 0],
                        well_pts[:, 1],
                        well_pts[:, 2],
                        s=markersize * 1.6,
                        color="k",
                        alpha=0.9,
                        zorder=5,
                    )
            # area outline
            for xs, ys, zs in outline_paths:
                ax.plot(xs, ys, zs, color="black")
            # labels & limits
            ax.set_xlabel(xlabel_final)
            ax.set_ylabel(ylabel_final)
            ax.set_zlabel(zlabel_final)
            if extent is not None and zlim is None:
                ax.set_xlim(extent[0], extent[3])
                ax.set_ylim(extent[1], extent[4])
                ax.set_zlim(extent[2], extent[5])
            else:
                if xlim is not None:
                    ax.set_xlim(xlim)
                if ylim is not None:
                    ax.set_ylim(ylim)
                if zlim is not None:
                    ax.set_zlim(zlim)
            ax.grid(True)
            return sc_well

        # plot all panels; the NW panel supplies the well mappable
        sc_well_nw = _plot_on(ax_nw)
        _plot_on(ax_ne)
        _plot_on(ax_sw)
        _plot_on(ax_se)

        # main colorbar (once, into its dedicated axis)
        if show_main:
            sm = plt.cm.ScalarMappable(cmap=cmap_main_obj, norm=norm_main)
            cb = fig.colorbar(sm, cax=cax_main)
            cb.set_label(units)
            cb.locator = MaxNLocator(nbins=6)
            cb.update_ticks()
            cb.ax.tick_params(labelsize=9)
        # well colorbar (once, if there are well values and the user wants it)
        if show_well and sc_well_nw is not None:
            cbw = fig.colorbar(sc_well_nw, cax=cax_well)
            cbw.set_label(well_units)
            cbw.locator = MaxNLocator(nbins=6)
            cbw.update_ticks()
            cbw.ax.tick_params(labelsize=9)
            cbw.ax.yaxis.set_ticks_position("left")
            cbw.ax.yaxis.set_label_position("left")

        # clearer view labels
        ax_nw.set_title(f"{title} — looking NW")
        ax_ne.set_title(f"{title} — looking NE")
        ax_sw.set_title(f"{title} — looking SW")
        ax_se.set_title(f"{title} — looking SE")
        plt.show()