# Source code for compass.plugin.ordinance
"""Helper classes for ordinance plugins"""
import asyncio
import logging
from warnings import warn
from textwrap import dedent
from itertools import chain
from functools import cached_property, partial
from abc import ABC, abstractmethod
from contextlib import contextmanager
import pandas as pd
from elm import ApiBase
from compass.llm.calling import (
LLMCaller,
BaseLLMCaller,
ChatLLMCaller,
JSONFromTextLLMCaller,
)
from compass.plugin.interface import (
BaseHeuristic,
BaseTextCollector,
FilteredExtractionPlugin,
)
from compass.services.threaded import CLEANED_FP_REGISTRY
from compass.extraction import extract_ordinance_values
from compass.utilities.enums import LLMTasks, LLMUsageCategory
from compass.utilities.ngrams import convert_text_to_sentence_ngrams
from compass.utilities.parsing import (
clean_backticks_from_llm_response,
extract_year_from_doc_attrs,
merge_overlapping_texts,
)
from compass.utilities import num_ordinances_dataframe
from compass.warn import COMPASSWarning
from compass.exceptions import (
COMPASSPluginConfigurationError,
COMPASSRuntimeError,
)
from compass.pb import COMPASS_PB
logger = logging.getLogger(__name__)

# Features that, on their own, are not sufficient for a document to
# count as an ordinance. Prohibitions are deliberately absent from this
# set, so a document containing only a prohibition still qualifies.
EXCLUDE_FROM_ORD_DOC_CHECK = {
    "accessory use districts",
    "climbing prevention",
    "color",
    "decommissioning",
    "fencing",
    "glare",
    "lighting",
    "primary use districts",
    "repowering",
    "signage",
    "soil",
    "special use districts",
    "visual impact",
}
class BaseTextExtractor(BaseLLMCaller, ABC):
    """Extract succinct extraction text from input

    Abstract base class for text extractors. Subclasses declare the
    attribute key they read text from (`IN_LABEL`), the attribute key
    they write condensed text to (`OUT_LABEL`), and a chain of
    `parsers` that perform the condensation.
    """

    TASK_DESCRIPTION = "Condensing text for extraction"
    """Task description to show in progress bar"""

    TASK_ID = "text_extraction"
    """ID to use for this extraction for linking with LLM configs"""

    # Usage-tracking category under which this extractor's LLM calls
    # are recorded
    _USAGE_LABEL = LLMUsageCategory.DOCUMENT_ORDINANCE_SUMMARY

    @property
    @abstractmethod
    def IN_LABEL(self):  # noqa: N802
        """str: Identifier for text ingested by this class"""
        raise NotImplementedError

    @property
    @abstractmethod
    def OUT_LABEL(self):  # noqa: N802
        """str: Identifier for final text extracted by this class"""
        raise NotImplementedError

    @property
    @abstractmethod
    def parsers(self):
        """Generator: Generator of (key, extractor) pairs

        `extractor` should be an async callable that accepts a list of
        text chunks and returns the shortened (succinct) text to be used
        for extraction. The `key` should be a string identifier for the
        text returned by the extractor. Multiple (key, extractor) pairs
        can be chained in generator order to iteratively refine the
        text for extraction.
        """
        raise NotImplementedError
class BaseParser(ABC):
    """Extract structured data from text input

    Abstract base class for parsers. Subclasses declare the attribute
    key they read text from (`IN_LABEL`), the attribute key they write
    structured data to (`OUT_LABEL`), and implement :meth:`parse`.
    """

    TASK_ID = "data_extraction"
    """ID to use for this extraction for linking with LLM configs"""

    @property
    @abstractmethod
    def IN_LABEL(self):  # noqa: N802
        """str: Identifier for text ingested by this class"""
        raise NotImplementedError

    @property
    @abstractmethod
    def OUT_LABEL(self):  # noqa: N802
        """str: Identifier for final structured data output"""
        raise NotImplementedError

    @abstractmethod
    async def parse(self, text):
        """Parse text and extract structured data

        Parameters
        ----------
        text : str
            Text which may or may not contain information relevant to
            the current extraction.

        Returns
        -------
        pandas.DataFrame or None
            DataFrame containing structured extracted data. Can also
            be ``None`` if no relevant values can be parsed from the
            text.
        """
        raise NotImplementedError
class KeywordBasedHeuristic(BaseHeuristic, ABC):
    """Perform a heuristic check for mention of a technology in text"""

    # Contexts in which an acronym counts as a real mention (i.e.
    # delimited by whitespace/punctuation, not embedded in a longer
    # token)
    _GOOD_ACRONYM_CONTEXTS = [
        " {acronym} ",
        " {acronym}\n",
        " {acronym}.",
        "\n{acronym} ",
        "\n{acronym}.",
        "\n{acronym}\n",
        "({acronym} ",
        " {acronym})",
    ]

    def check(self, text, match_count_threshold=1):
        """Check for mention of a tech in text

        This check first strips the text of any tech "look-alike" words
        (e.g. "window", "windshield", etc for "wind" technology). Then,
        it checks for particular keywords, acronyms, and phrases that
        pertain to the tech in the text. If enough keywords are
        mentioned (as dictated by `match_count_threshold`), this check
        returns ``True``.

        Parameters
        ----------
        text : str
            Input text that may or may not mention the technology of
            interest.
        match_count_threshold : int, optional
            Number of keywords that must match for the text to pass this
            heuristic check. Count must be strictly greater than this
            value. By default, ``1``.

        Returns
        -------
        bool
            ``True`` if the number of keywords/acronyms/phrases detected
            exceeds the `match_count_threshold`.
        """
        heuristics_text = self._convert_to_heuristics_text(text)
        total_keyword_matches = self._count_single_keyword_matches(
            heuristics_text
        )
        total_keyword_matches += self._count_acronym_matches(heuristics_text)
        total_keyword_matches += self._count_phrase_matches(heuristics_text)
        return total_keyword_matches > match_count_threshold

    def _convert_to_heuristics_text(self, text):
        """Convert text for heuristic content parsing"""
        heuristics_text = text.casefold()
        # Remove "look-alike" words so they cannot trigger keyword hits
        for word in self.NOT_TECH_WORDS:
            heuristics_text = heuristics_text.replace(word, "")
        return heuristics_text

    def _count_single_keyword_matches(self, heuristics_text):
        """Count number of good tech keywords that appear in text"""
        return sum(
            keyword in heuristics_text for keyword in self.GOOD_TECH_KEYWORDS
        )

    def _count_acronym_matches(self, heuristics_text):
        """Count number of good tech acronyms that appear in text

        Only the first context that yields any matches contributes to
        the count (note the ``break``); this avoids double-counting the
        same acronym appearing in multiple contexts.
        """
        acronym_matches = 0
        for context in self._GOOD_ACRONYM_CONTEXTS:
            acronym_keywords = {
                context.format(acronym=acronym)
                for acronym in self.GOOD_TECH_ACRONYMS
            }
            acronym_matches = sum(
                keyword in heuristics_text for keyword in acronym_keywords
            )
            if acronym_matches > 0:
                break
        return acronym_matches

    def _count_phrase_matches(self, heuristics_text):
        """Count number of good tech phrases that appear in text"""
        text_ngrams = {}
        total = 0
        for phrase in self.GOOD_TECH_PHRASES:
            n = len(phrase.split(" "))
            if n <= 1:
                msg = (
                    "Make sure your GOOD_TECH_PHRASES contain at least 2 "
                    f"words! Got phrase: {phrase!r}"
                )
                warn(msg, COMPASSWarning)
                continue
            # Lazily compute the text's n-grams once per phrase length
            if n not in text_ngrams:
                text_ngrams[n] = set(
                    convert_text_to_sentence_ngrams(heuristics_text, n)
                )
            # Check both the phrase and its naive plural form
            test_ngrams = (
                convert_text_to_sentence_ngrams(phrase, n)
                + convert_text_to_sentence_ngrams(f"{phrase}s", n)
            )
            if any(t in text_ngrams[n] for t in test_ngrams):
                total += 1
        return total

    @property
    @abstractmethod
    def NOT_TECH_WORDS(self):  # noqa: N802
        """:class:`~collections.abc.Iterable`: Not tech keywords"""
        raise NotImplementedError

    @property
    @abstractmethod
    def GOOD_TECH_KEYWORDS(self):  # noqa: N802
        """:class:`~collections.abc.Iterable`: Tech keywords"""
        raise NotImplementedError

    @property
    @abstractmethod
    def GOOD_TECH_ACRONYMS(self):  # noqa: N802
        """:class:`~collections.abc.Iterable`: Tech acronyms"""
        raise NotImplementedError

    @property
    @abstractmethod
    def GOOD_TECH_PHRASES(self):  # noqa: N802
        """:class:`~collections.abc.Iterable`: Tech phrases"""
        raise NotImplementedError
class PromptBasedTextCollector(JSONFromTextLLMCaller, BaseTextCollector, ABC):
    """Text collector based on a chain of filter prompts"""

    @property
    @abstractmethod
    def PROMPTS(self):  # noqa: N802
        """list: List of dicts defining the prompts for text extraction

        Each dict in the list should have the following keys:

            - **prompt**: [REQUIRED] The text filter prompt to use
              to determine if a chunk of text is relevant for the
              current extraction task. The prompt must instruct the LLM
              to return a dictionary (as JSON) with at least one key
              that outputs the filter decision. The prompt may use the
              following placeholders, which will be filled in with the
              corresponding class attributes when the prompt is applied:

                - ``"{key}"``: The key corresponding to this prompt.

            - **key**: [REQUIRED] A string identifier for the key that
              in the output JSON dictionary that represents the LLM
              filter decision (``True`` if the tech chunk should be
              kept, and ``False`` otherwise).
            - **label**: [OPTIONAL] A string label describing the type
              of relevant text this prompt is looking for (e.g. "wind
              energy conversion system ordinance text"). This is only
              used for logging purposes and does not affect the
              extraction process itself. If not provided, this will
              default to "collector step {i}".

        The prompts will be applied in the order they appear in the
        list, with the output text from each prompt being fed as input
        to the next prompt in the chain. If any of the filter decisions
        return ``False``, the text will be discarded and not passed to
        subsequent prompts. The final output of the last prompt will
        determine whether or not the chunk of text being evaluated is
        kept as relevant text for extraction.
        """
        raise NotImplementedError

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps chunk index -> chunk text for every chunk that passed
        # all filter prompts (plus recalled neighbor chunks)
        self._chunks = {}

    @property
    def relevant_text(self):
        """str: Combined ordinance text from the individual chunks"""
        if not self._chunks:
            logger.debug(
                "No relevant ordinance chunk(s) found in original text",
            )
            return ""
        logger.debug(
            "Grabbing %d ordinance chunk(s) from original text at these "
            "indices: %s",
            len(self._chunks),
            list(self._chunks),
        )
        # Re-assemble chunks in document order, merging overlapping text
        text = [self._chunks[ind] for ind in sorted(self._chunks)]
        return merge_overlapping_texts(text)

    async def check_chunk(self, chunk_parser, ind):
        """Check a chunk at a given ind to see if it contains ordinance

        Parameters
        ----------
        chunk_parser : ParseChunksWithMemory
            Instance that contains a ``parse_from_ind`` method.
        ind : int
            Index of the chunk to check.

        Returns
        -------
        bool
            Boolean flag indicating whether or not the text in the chunk
            contains relevant ordinance text.
        """
        for collection_step, prompt_dict in enumerate(self.PROMPTS):
            key = prompt_dict["key"]
            # Default label matches documented "collector step {i}" form
            label = prompt_dict.get(
                "label", f"collector step {collection_step}"
            )
            # Pass the raw prompt through; it is formatted exactly once
            # in `_check_chunk_with_prompt` (formatting twice could
            # choke on literal braces produced by the first pass)
            passed_filter = await chunk_parser.parse_from_ind(
                ind,
                key=key,
                llm_call_callback=self._check_chunk_with_prompt,
                prompt=prompt_dict["prompt"],
            )
            if not passed_filter:
                logger.debug(
                    "Text at ind %d did not pass collection step: %s",
                    ind,
                    label,
                )
                return False
            logger.debug(
                "Text at ind %d passed collection step: %s ", ind, label
            )
        self._store_chunk(chunk_parser, ind)
        logger.debug("Added text chunk at ind %d to extraction text", ind)
        return True

    async def _check_chunk_with_prompt(self, key, text_chunk, prompt):
        """Call LLM on a chunk of text to check for ordinance"""
        content = await self.call(
            sys_msg=prompt.format(key=key),
            content=text_chunk,
            usage_sub_label=LLMUsageCategory.DOCUMENT_CONTENT_VALIDATION,
        )
        logger.debug("LLM response: %s", content)
        return content.get(key, False)

    def _store_chunk(self, parser, chunk_ind):
        """Store chunk and its neighbors if it is not already stored"""
        for offset in range(1 - parser.num_to_recall, 2):
            ind_to_grab = chunk_ind + offset
            if ind_to_grab < 0 or ind_to_grab >= len(parser.text_chunks):
                continue
            self._chunks.setdefault(
                ind_to_grab, parser.text_chunks[ind_to_grab]
            )
class PromptBasedTextExtractor(LLMCaller, BaseTextExtractor, ABC):
    """Text extractor based on a chain of prompts"""

    # Dedent + replace collapses the multi-line literal into a single
    # line for the LLM system message
    SYSTEM_MESSAGE = (
        dedent(
            """\
            You are a text extraction assistant. Your job is to extract only
            verbatim, **unmodified** excerpts from the provided text. Do not
            interpret or paraphrase. Do not summarize. Only return exactly copied
            segments that match the specified scope. If the relevant content
            appears within a table, return the entire table, including headers
            and footers, exactly as formatted.
            """
        )
        .replace("\n", " ")
        .strip()
    )
    """System message for text extraction LLM calls"""

    # NOTE(review): unlike SYSTEM_MESSAGE, this replaces "\n " (newline
    # plus space), presumably to join wrapped bullet continuations while
    # keeping bullet line breaks — confirm intent before changing
    FORMATTING_PROMPT = (
        dedent(
            """\
            ## Formatting & Structure ##:
            - **Preserve _all_ section titles, headers, and numberings** for
             reference.
            - **Maintain the original wording, formatting, and structure** to
             ensure accuracy.
            """
        )
        .replace("\n ", " ")
        .strip()
    )
    """Prompt component instructing model to preserve text structure"""

    OUTPUT_PROMPT = (
        dedent(
            """\
            ## Output Handling ##:
            - This is a strict extraction task — act like a text filter, **not**
             a summarizer or writer.
            - Do not add, explain, reword, or summarize anything.
            - The output must be a **copy-paste** of the original excerpt.
             **Absolutely no paraphrasing or rewriting.**
            - The output must consist **only** of contiguous or discontiguous
             verbatim blocks copied from the input.
            - The only allowed change is to remove irrelevant sections of text.
             You can remove irrelevant text from within sections, but you cannot
             add any new text or modify the text you keep in any way.
            - If **no relevant text** is found, return the response:
             'No relevant text.'
            """
        )
        .replace("\n ", " ")
        .strip()
    )
    """Prompt component instructing model output guidelines"""

    @property
    @abstractmethod
    def PROMPTS(self):  # noqa: N802
        """list: List of dicts defining the prompts for text extraction

        Each dict in the list should have the following keys:

            - **prompt**: [REQUIRED] The text extraction prompt to use
              for the extraction. The prompt may use the following
              placeholders, which will be filled in with the
              corresponding class attributes when the prompt is applied:

                - ``"{FORMATTING_PROMPT}"``: The
                  :obj:`PromptBasedTextExtractor.FORMATTING_PROMPT`
                  class attribute, which provides instructions for
                  preserving the formatting and structure of the
                  extracted text.
                - ``"{OUTPUT_PROMPT}"``: The
                  :obj:`PromptBasedTextExtractor.OUTPUT_PROMPT`
                  class attribute, which provides instructions for
                  how the model should format the output and what
                  content to include or exclude.

            - **key**: [OPTIONAL] A string identifier for the text
              extracted by this prompt. If not provided, a default key
              ``"extracted_text_{i}"`` will be used, where ``{i}`` is
              the index of the prompt in the list. The value of this key
              from the last dictionary in the input list will be used as
              this extractor's `OUT_LABEL`, which is typically used to
              link the extracted text to the appropriate parser via the
              parser's `IN_LABEL`. All `key` values should be unique
              across all prompts in the chain.
            - **out_fn**: [OPTIONAL] A file name template that will be
              used to write the extracted text to a file. The template
              can include the placeholder ``{jurisdiction}``, which
              will be replaced with the full jurisdiction name. If not
              provided, the extracted text will not be written to a
              file. This is primarily intended for debugging and
              analysis purposes, and is not required for the extraction
              process itself.

        The prompts will be applied in the order they appear in the
        list, with the output text from each prompt being fed as input
        to the next prompt in the chain. The final output of the last
        prompt will be the output of the extractor.
        """
        raise NotImplementedError

    def __init_subclass__(cls, **kwargs):
        # Derive OUT_LABEL from the last prompt dict of each concrete
        # subclass; abstract subclasses are skipped (they may not have
        # a concrete PROMPTS yet)
        super().__init_subclass__(**kwargs)
        if getattr(cls, "__abstractmethods__", None):
            return
        last_prompt = cls.PROMPTS[-1]
        last_index = len(cls.PROMPTS) - 1
        cls.OUT_LABEL = last_prompt.get("key", f"extracted_text_{last_index}")

    @property
    def parsers(self):
        """Iterable of parsers provided by this extractor

        Yields
        ------
        name : str
            Name describing the type of text output by the parser.
        parser : callable
            Async function that takes a ``text_chunks`` input and
            outputs parsed text.
        """
        for ind, prompt_dict in enumerate(self.PROMPTS):
            key = prompt_dict.get("key", f"extracted_text_{ind}")
            instructions = prompt_dict["prompt"].format(
                FORMATTING_PROMPT=self.FORMATTING_PROMPT,
                OUTPUT_PROMPT=self.OUTPUT_PROMPT,
            )
            yield key, partial(self._process, instructions=instructions)

    async def _process(self, text_chunks, instructions, is_valid_chunk=None):
        """Perform extraction processing

        Calls the LLM once per chunk concurrently, drops invalid
        responses, and merges the remaining summaries into one text.
        """
        if is_valid_chunk is None:
            is_valid_chunk = _valid_chunk
        logger.info(
            "Extracting summary text from %d text chunks asynchronously...",
            len(text_chunks),
        )
        logger.debug("Model instructions are:\n%s", instructions)
        # Name child tasks after the current task so they are grouped
        # together (e.g. per-jurisdiction) by downstream tracking
        outer_task_name = asyncio.current_task().get_name()
        summaries = [
            asyncio.create_task(
                self.call(
                    sys_msg=self.SYSTEM_MESSAGE,
                    content=f"{instructions}\n\n# TEXT #\n\n{chunk}",
                    usage_sub_label=self._USAGE_LABEL,
                ),
                name=outer_task_name,
            )
            for chunk in text_chunks
        ]
        summary_chunks = await asyncio.gather(*summaries)
        # Discard "No relevant text." responses and strip markdown
        # code fences from the rest
        summary_chunks = [
            clean_backticks_from_llm_response(chunk)
            for chunk in summary_chunks
            if is_valid_chunk(chunk)
        ]
        text_summary = merge_overlapping_texts(summary_chunks)
        logger.debug(
            "Final summary contains %d tokens",
            ApiBase.count_tokens(
                text_summary, model=self.kwargs.get("model", "gpt-4")
            ),
        )
        return text_summary
class OrdinanceParser(BaseLLMCaller, BaseParser):
    """Base class for parsing structured data"""

    def _init_chat_llm_caller(self, system_message):
        """Build a ``ChatLLMCaller`` instance for the DecisionTree

        Parameters
        ----------
        system_message : str
            System message to seed the chat-based caller with.

        Returns
        -------
        ChatLLMCaller
            Caller configured with this parser's LLM service, usage
            tracker, and keyword arguments.
        """
        caller_kwargs = dict(self.kwargs)
        return ChatLLMCaller(
            self.llm_service,
            system_message=system_message,
            usage_tracker=self.usage_tracker,
            **caller_kwargs,
        )
class OrdinanceExtractionPlugin(FilteredExtractionPlugin):
    """Base class for COMPASS extraction plugins

    This class provides a good balance between ease of use and
    extraction flexibility, allowing implementers to provide additional
    functionality during the extraction process.

    Plugins can hook into various stages of the extraction pipeline
    to modify behavior, add custom processing, or integrate with
    external systems.

    Subclasses should implement the desired hooks and override
    methods as needed.
    """

    ALLOW_MULTI_DOC_EXTRACTION = False
    """bool: Whether to allow extraction over multiple documents"""

    @property
    @abstractmethod
    def TEXT_EXTRACTORS(self):  # noqa: N802
        """list of BaseTextExtractor: Classes to condense text

        Should be an iterable of one or more classes to condense text in
        preparation for the extraction task.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def PARSERS(self):  # noqa: N802
        """list of BaseParser: Classes to extract structured data

        Should be an iterable of one or more classes to extract
        structured data from text.
        """
        raise NotImplementedError

    @cached_property
    def producers(self):
        """list: All classes that produce attributes on the doc"""
        # Materialized to a list (not a one-shot `chain` iterator) so
        # repeated accesses of this cached property stay valid
        return [*self.PARSERS, *self.TEXT_EXTRACTORS, *self.TEXT_COLLECTORS]

    @cached_property
    def consumer_producer_pairs(self):
        """list: Pairs of (consumer, producer) for IN/OUT validation"""
        # Producer groups are materialized so the cached result can be
        # iterated more than once
        return [
            (
                self.PARSERS,
                [*self.TEXT_EXTRACTORS, *self.TEXT_COLLECTORS],
            ),
            (self.TEXT_EXTRACTORS, self.TEXT_COLLECTORS),
        ]

    async def extract_ordinances_from_text(
        self, doc, parser_class, model_config
    ):
        """Extract structured data from input text

        The extracted structured data will be stored in the ``.attrs``
        dictionary of the input document under the
        ``parser_class.OUT_LABEL`` key.

        Parameters
        ----------
        doc : BaseDocument
            Document containing text to extract structured data from.
        parser_class : BaseParser
            Class to use for structured data extraction.
        model_config : LLMConfig
            Configuration for the LLM model to use for structured data
            extraction.
        """
        parser = parser_class(
            llm_service=model_config.llm_service,
            usage_tracker=self.usage_tracker,
            **model_config.llm_call_kwargs,
        )
        logger.info(
            "Extracting %s...", parser_class.OUT_LABEL.replace("_", " ")
        )
        await extract_ordinance_values(
            doc,
            parser,
            text_key=parser_class.IN_LABEL,
            out_key=parser_class.OUT_LABEL,
        )

    @classmethod
    def get_structured_data_row_count(cls, data_df):
        """Get the number of data rows extracted from a document

        Parameters
        ----------
        data_df : pandas.DataFrame or None
            DataFrame to check for extracted structured data.

        Returns
        -------
        int
            Number of data rows extracted from the document.
        """
        if data_df is None:
            return 0
        # Features in EXCLUDE_FROM_ORD_DOC_CHECK alone don't qualify a
        # document as an ordinance source
        return num_ordinances_dataframe(
            data_df, exclude_features=EXCLUDE_FROM_ORD_DOC_CHECK
        )

    async def parse_docs_for_structured_data(self, extraction_context):
        """Parse documents to extract structured data/information

        Parameters
        ----------
        extraction_context : ExtractionContext
            Context containing candidate documents to parse.

        Returns
        -------
        ExtractionContext or None
            Context with extracted data/information stored in the
            ``.attrs`` dictionary, or ``None`` if no data was extracted.
        """
        if self.ALLOW_MULTI_DOC_EXTRACTION:
            return await self.parse_multi_doc_context_for_structured_data(
                extraction_context
            )
        return await self.parse_single_doc_for_structured_data(
            extraction_context
        )

    async def parse_multi_doc_context_for_structured_data(
        self, extraction_context
    ):
        """Parse all documents to extract structured data/information

        Parameters
        ----------
        extraction_context : ExtractionContext
            Context containing candidate documents to parse. The text
            from all documents will be concatenated to create the
            context for the extraction.

        Returns
        -------
        ExtractionContext or None
            Context with extracted data/information stored in the
            ``.attrs`` dictionary, or ``None`` if no data was extracted.
        """
        # Concatenate the collected text from all docs under the final
        # collector's output key so parsers see one combined context
        key = self.TEXT_COLLECTORS[-1].OUT_LABEL
        extraction_context.attrs[key] = extraction_context.multi_doc_context(
            attr_text_key=key
        )
        data_df = await self.parse_for_structured_data(extraction_context)
        row_count = self.get_structured_data_row_count(data_df)
        if row_count == 0:
            logger.debug(
                "No extracted data; searched %d docs",
                extraction_context.num_documents,
            )
            return None
        data_df = await _fill_out_multi_file_sources(
            data_df,
            extraction_context,
            out_fn_stem=self.jurisdiction.full_name,
        )
        extraction_context.attrs["structured_data"] = data_df
        logger.info(
            "%d ordinance value(s) found in %d docs for %s. ",
            num_ordinances_dataframe(data_df),
            extraction_context.num_documents,
            self.jurisdiction.full_name,
        )
        return extraction_context

    async def parse_single_doc_for_structured_data(self, extraction_context):
        """Parse documents one at a time to extract structured data

        The first document to return some extracted data will be marked
        as the source and will be returned from this method.

        Parameters
        ----------
        extraction_context : ExtractionContext
            Context containing candidate documents to parse.

        Returns
        -------
        ExtractionContext or None
            Context with extracted data/information stored in the
            ``.attrs`` dictionary, or ``None`` if no data was extracted.
        """
        for doc_for_extraction in extraction_context:
            data_df = await self.parse_for_structured_data(doc_for_extraction)
            row_count = self.get_structured_data_row_count(data_df)
            if row_count > 0:
                data_df["source"] = doc_for_extraction.attrs.get("source")
                data_df["year"] = extract_year_from_doc_attrs(
                    doc_for_extraction.attrs
                )
                await extraction_context.mark_doc_as_data_source(
                    doc_for_extraction, out_fn_stem=self.jurisdiction.full_name
                )
                extraction_context.attrs["structured_data"] = data_df
                logger.info(
                    "%d ordinance value(s) found in doc from %s for %s. ",
                    num_ordinances_dataframe(data_df),
                    doc_for_extraction.attrs.get("source", "unknown source"),
                    self.jurisdiction.full_name,
                )
                # First document with data wins; remaining docs skipped
                return extraction_context
        logger.debug(
            "No ordinances found; searched %d docs",
            extraction_context.num_documents,
        )
        return None

    async def parse_for_structured_data(self, source):
        """Extract all possible structured data from a document

        This method is called from the default implementation of
        `parse_single_doc_for_structured_data()` for each document that
        passed filtering. If you overwrite
        ``parse_single_doc_for_structured_data()``, you can ignore this
        method.

        Parameters
        ----------
        source : BaseDocument or ExtractionContext
            Source to extract structured data from. Must have an
            `.attrs` attribute that contains text from which data should
            be extracted.

        Returns
        -------
        pandas.DataFrame or None
            DataFrame containing extracted structured data, or None if
            no structured data were extracted.
        """
        with self._tracked_progress():
            tasks = [
                asyncio.create_task(
                    self._try_extract_ordinances(source, parser_class),
                    name=self.jurisdiction.full_name,
                )
                for parser_class in filter(None, self.PARSERS)
            ]
            await asyncio.gather(*tasks)
        return self._concat_scrape_results(source)

    async def _try_extract_ordinances(self, doc_for_extraction, parser_class):
        """Apply a single extractor and parser to legal text"""
        # Run the text extractor chain first if the parser's input text
        # is not already present on the document
        if parser_class.IN_LABEL not in doc_for_extraction.attrs:
            await self._run_text_extractors(doc_for_extraction, parser_class)
        model_config = self._get_model_config(
            primary_key=parser_class.TASK_ID,
            secondary_key=LLMTasks.DATA_EXTRACTION,
        )
        await self.extract_ordinances_from_text(
            doc_for_extraction,
            parser_class=parser_class,
            model_config=model_config,
        )
        await self.record_usage()

    async def _run_text_extractors(self, doc_for_extraction, parser_class):
        """Run text extractor(s) on document to get text for a parser"""
        te = [
            te
            for te in self.TEXT_EXTRACTORS
            if te.OUT_LABEL == parser_class.IN_LABEL
        ]
        if len(te) != 1:
            msg = (
                f"Could not find unique text extractor for parser "
                f"{parser_class.__name__} with IN_LABEL "
                f"{parser_class.IN_LABEL!r}. Got matches: {te}"
            )
            raise COMPASSPluginConfigurationError(msg)
        te = te[0]
        model_config = self._get_model_config(
            primary_key=te.TASK_ID,
            secondary_key=LLMTasks.TEXT_EXTRACTION,
        )
        logger.debug(
            "Condensing text for extraction using %r for doc from %s",
            te.__name__,
            doc_for_extraction.attrs.get("source", "unknown source"),
        )
        # Internal invariant: _tracked_progress() must be active here
        assert self._jsp is not None, "No progress bar set!"
        task_id = self._jsp.add_task(te.TASK_DESCRIPTION)
        await self.extract_relevant_text(doc_for_extraction, te, model_config)
        await self.record_usage()
        self._jsp.remove_task(task_id)

    @contextmanager
    def _tracked_progress(self):
        """Context manager to set up jurisdiction sub-progress bar"""
        loc = self.jurisdiction.full_name
        with COMPASS_PB.jurisdiction_sub_prog(loc) as self._jsp:
            yield
        self._jsp = None

    def _concat_scrape_results(self, source):
        """Concatenate structured data from all parsers"""
        data = [source.attrs.get(p.OUT_LABEL, None) for p in self.PARSERS]
        data = [df for df in data if df is not None and not df.empty]
        if len(data) == 0:
            return None
        return data[0] if len(data) == 1 else pd.concat(data)

    def _get_model_config(self, primary_key, secondary_key):
        """Get model config: primary_key -> secondary_key -> default"""
        if primary_key in self.model_configs:
            return self.model_configs[primary_key]
        return self.model_configs.get(
            secondary_key, self.model_configs[LLMTasks.DEFAULT]
        )

    def validate_plugin_configuration(self):
        """[NOT PUBLIC API] Validate plugin is properly configured"""
        super().validate_plugin_configuration()
        self._validate_text_extractors()
        self._validate_parsers()
        self._validate_in_out_keys()
        self._validate_collector_prompts()
        self._validate_extractor_prompts()
        self._register_clean_file_names()

    def _validate_text_extractors(self):
        """Validate user provided at least one text extractor class"""
        try:
            extractors = self.TEXT_EXTRACTORS
        except NotImplementedError:
            msg = (
                f"Plugin class {self.__class__.__name__} is missing required "
                "property 'TEXT_EXTRACTORS'"
            )
            raise COMPASSPluginConfigurationError(msg) from None
        if len(extractors) == 0:
            msg = (
                f"Plugin class {self.__class__.__name__} has an empty "
                "'TEXT_EXTRACTORS' property! Please provide at least "
                "one text extractor class."
            )
            raise COMPASSPluginConfigurationError(msg)
        for extractor_class in extractors:
            if not issubclass(extractor_class, BaseTextExtractor):
                msg = (
                    f"Plugin class {self.__class__.__name__} has invalid "
                    "entry in 'TEXT_EXTRACTORS' property: All entries must "
                    "be subclasses of "
                    "compass.plugin.ordinance.BaseTextExtractor, but "
                    f"{extractor_class.__name__} is not!"
                )
                raise COMPASSPluginConfigurationError(msg)

    def _validate_parsers(self):
        """Validate user provided at least one parser class"""
        try:
            parsers = self.PARSERS
        except NotImplementedError:
            msg = (
                f"Plugin class {self.__class__.__name__} is missing required "
                "property 'PARSERS'"
            )
            raise COMPASSPluginConfigurationError(msg) from None
        if len(parsers) == 0:
            msg = (
                f"Plugin class {self.__class__.__name__} has an empty "
                "'PARSERS' property! Please provide at least "
                "one parser class."
            )
            raise COMPASSPluginConfigurationError(msg)
        for parsers_class in parsers:
            if not issubclass(parsers_class, BaseParser):
                msg = (
                    f"Plugin class {self.__class__.__name__} has invalid "
                    "entry in 'PARSERS' property: All entries must "
                    "be subclasses of "
                    "compass.plugin.ordinance.BaseParser, but "
                    f"{parsers_class.__name__} is not!"
                )
                raise COMPASSPluginConfigurationError(msg)

    def _validate_in_out_keys(self):
        """Validate that all IN_LABELs have matching OUT_LABELs"""
        out_keys = {}
        for producer in self.producers:
            out_keys.setdefault(producer.OUT_LABEL, []).append(producer)
        dupes = {k: v for k, v in out_keys.items() if len(v) > 1}
        if dupes:
            formatted = "\n".join(
                [
                    f"{key}: {[cls.__name__ for cls in classes]}"
                    for key, classes in dupes.items()
                ]
            )
            msg = (
                "Multiple processing classes produce the same OUT_LABEL key:\n"
                f"{formatted}"
            )
            raise COMPASSPluginConfigurationError(msg)
        for consumers, producers in self.consumer_producer_pairs:
            _validate_in_out_keys(consumers, producers)

    def _validate_collector_prompts(self):
        """Validate that all text collectors have prompts defined"""
        for collector in self.TEXT_COLLECTORS:
            if not issubclass(collector, PromptBasedTextCollector):
                continue
            try:
                # NOTE(review): class-level access to an un-overridden
                # abstract property yields the property object, not a
                # NotImplementedError — confirm subclasses define
                # PROMPTS as a plain class attribute
                num_prompts = len(collector.PROMPTS)
            except NotImplementedError:
                msg = (
                    f"Text collector {collector.__name__} is missing "
                    "required property 'PROMPTS'"
                )
                raise COMPASSPluginConfigurationError(msg) from None
            if num_prompts == 0:
                msg = (
                    f"Text collector {collector.__name__} has an empty "
                    "'PROMPTS' property! Please provide at least one prompt "
                    "dictionary."
                )
                raise COMPASSPluginConfigurationError(msg)

    def _validate_extractor_prompts(self):
        """Validate that all text extractors have prompts defined"""
        for extractor in self.TEXT_EXTRACTORS:
            if not issubclass(extractor, PromptBasedTextExtractor):
                continue
            try:
                num_prompts = len(extractor.PROMPTS)
            except NotImplementedError:
                msg = (
                    f"Text extractor {extractor.__name__} is missing "
                    "required property 'PROMPTS'"
                )
                raise COMPASSPluginConfigurationError(msg) from None
            if num_prompts == 0:
                msg = (
                    f"Text extractor {extractor.__name__} has an empty "
                    "'PROMPTS' property! Please provide at least one prompt "
                    "dictionary."
                )
                raise COMPASSPluginConfigurationError(msg)

    def _register_clean_file_names(self):
        """Register file names for writing cleaned text outputs"""
        CLEANED_FP_REGISTRY.setdefault(self.IDENTIFIER.casefold(), {})
        for extractor_class in self.TEXT_EXTRACTORS:
            if not issubclass(extractor_class, PromptBasedTextExtractor):
                continue
            for ind, prompt_dict in enumerate(extractor_class.PROMPTS):
                out_fn = prompt_dict.get("out_fn", None)
                if not out_fn:
                    continue
                key = prompt_dict.get("key", f"extracted_text_{ind}")
                CLEANED_FP_REGISTRY[self.IDENTIFIER.casefold()][key] = out_fn
def _valid_chunk(chunk):
"""True if chunk has content"""
return chunk and "no relevant text" not in chunk.lower()
def _validate_in_out_keys(consumers, producers):
"""Validate that all IN_LABELs have matching OUT_LABELs"""
in_keys = {}
out_keys = {}
for producer_class in producers:
out_keys.setdefault(producer_class.OUT_LABEL, []).append(
producer_class
)
for consumer_class in chain(consumers):
in_keys.setdefault(consumer_class.IN_LABEL, []).append(consumer_class)
for in_key, classes in in_keys.items():
formatted = f"{[cls.__name__ for cls in classes]}"
if in_key not in out_keys:
msg = (
f"One or more processing classes require IN_LABEL "
f"{in_key!r}, which is not produced by any previous "
f"processing class: {formatted}"
)
raise COMPASSPluginConfigurationError(msg)
async def _fill_out_multi_file_sources(
    data_df, extraction_context, out_fn_stem
):
    """Fill out source column for multi-doc extraction

    This method implements a "report all document" fallback for the
    following scenarios:

        - source inds not given in output
        - source inds not integers
        - source inds are invalid indices for the actual documents

    If the source inds are all valid, each row in the dataframe gets its
    own unique source and year combo.
    """
    try:
        source_inds = _get_source_inds(
            data_df, extraction_context.num_documents
        )
    except COMPASSRuntimeError:
        # Indices unusable — attribute every document as the source
        return await _fill_in_all_sources(
            data_df, extraction_context, out_fn_stem
        )
    # Build index -> year and index -> source lookups while marking
    # each referenced document as a data source
    year_map = {}
    source_map = {}
    for source_ind in source_inds:
        doc = extraction_context[source_ind]
        year_map[source_ind] = extract_year_from_doc_attrs(doc.attrs)
        source_map[source_ind] = doc.attrs.get("source")
        await extraction_context.mark_doc_as_data_source(
            doc, out_fn_stem=f"{out_fn_stem}_{source_ind + 1}"
        )
    # NOTE: "year" must be mapped before "source" is overwritten below,
    # since both lookups are keyed on the original integer indices in
    # the "source" column
    data_df["year"] = data_df["source"].map(
        lambda source_ind: (
            year_map.get(int(source_ind)) if pd.notna(source_ind) else None
        )
    )
    data_df["source"] = data_df["source"].map(
        lambda source_ind: (
            source_map.get(int(source_ind)) if pd.notna(source_ind) else None
        )
    )
    return data_df
def _get_source_inds(data_df, num_docs):
"""Try to extract source document indices"""
if "source" not in data_df.columns:
msg = "'source' column not found in extracted outputs"
raise COMPASSRuntimeError(msg)
try:
source_inds = data_df["source"].dropna().unique().astype(int)
except (TypeError, ValueError):
msg = "'source' column contains non-integer values"
raise COMPASSRuntimeError(msg) from None
if any(
source_ind < 0 or source_ind >= num_docs for source_ind in source_inds
):
msg = "'source' column contains out-of-bounds indices"
raise COMPASSRuntimeError(msg)
return source_inds
async def _fill_in_all_sources(data_df, extraction_context, out_fn_stem):
    """Fill in source and year columns using all sources

    Fallback used when per-row source indices are missing or invalid:
    every document in the context is reported as a combined source, and
    the most recent available year is used for all rows.
    """
    logger.debug(
        "Filling in sources using all %d documents in context due to "
        "invalid or missing source indices",
        extraction_context.num_documents,
    )
    doc_sources = [doc.attrs.get("source") for doc in extraction_context]
    doc_sources = [src for src in doc_sources if src]
    data_df["source"] = " ;\n".join(doc_sources) or None
    doc_years = [
        extract_year_from_doc_attrs(doc.attrs) for doc in extraction_context
    ]
    doc_years = [year for year in doc_years if year]
    data_df["year"] = max(doc_years) if doc_years else None
    for doc_num, doc in enumerate(extraction_context, start=1):
        await extraction_context.mark_doc_as_data_source(
            doc, out_fn_stem=f"{out_fn_stem}_{doc_num}"
        )
    return data_df