# Source code for compass.plugin.interface
"""COMPASS extraction plugin base class"""
import logging
from abc import ABC, abstractmethod
from compass.plugin.base import BaseExtractionPlugin
from compass.llm.calling import BaseLLMCaller
from compass.extraction import extract_relevant_text_with_ngram_validation
from compass.scripts.download import filter_ordinance_docs
from compass.services.threaded import CLEANED_FP_REGISTRY, CleanedFileWriter
from compass.utilities import doc_infos_to_db, save_db
from compass.exceptions import COMPASSPluginConfigurationError
logger = logging.getLogger(__name__)
class BaseHeuristic(ABC):
    """Interface for lightweight relevance checks on raw text"""

    @abstractmethod
    def check(self, text):
        """Determine whether the text mentions the technology of interest

        Parameters
        ----------
        text : str
            Text (or a chunk of text) to screen for mentions of the
            technology being extracted.

        Returns
        -------
        bool
            ``True`` when the text passes the heuristic check,
            ``False`` otherwise.
        """
        raise NotImplementedError
class BaseTextCollector(BaseLLMCaller, ABC):
    """Interface for collectors that gather text relevant to a task"""

    @property
    @abstractmethod
    def OUT_LABEL(self):  # noqa: N802
        """str: Label identifying the text collected by this class"""
        raise NotImplementedError

    @property
    @abstractmethod
    def relevant_text(self):
        """str: All relevant text from the individual chunks, combined"""
        raise NotImplementedError

    @abstractmethod
    async def check_chunk(self, chunk_parser, ind):
        """Decide whether a single text chunk matters for extraction

        Implementations should validate chunks like so::

            is_correct_kind_of_text = await chunk_parser.parse_from_ind(
                ind,
                key="my_unique_validation_key",
                llm_call_callback=my_async_llm_call_function,
            )

        Here the `"key"` must be unique to this particular validation
        (it is used to cache the validation result in the chunk
        parser's memory), and `my_async_llm_call_function` is an async
        function that accepts a key and a text chunk and returns a
        boolean signaling whether the chunk passes the validation.
        ``chunk_parser.parse_from_ind`` may be invoked any number of
        times within this method, as long as every validation uses its
        own unique key.

        Parameters
        ----------
        chunk_parser : ParseChunksWithMemory
            Instance that contains a ``parse_from_ind`` method.
        ind : int
            Index of the chunk to check.

        Returns
        -------
        bool
            Boolean flag indicating whether or not the text in the chunk
            contains information relevant to the extraction task.

        See Also
        --------
        :func:`~compass.validation.content.ParseChunksWithMemory.parse_from_ind`
            Method used to parse text from a chunk with memory of prior
            chunk validations.
        """
        raise NotImplementedError
class FilteredExtractionPlugin(BaseExtractionPlugin):
    """Base class for COMPASS extraction plugins

    This class provides the standard COMPASS document filtering and text
    collection pipeline, allowing implementers to focus primarily on the
    structured data extraction step. Filtering and text collection is
    provided by subclassing the `BaseTextCollector` class and setting
    the `TEXT_COLLECTORS` property to a list of the desired text
    collectors.

    Plugins can hook into various stages of the extraction pipeline
    to modify behavior, add custom processing, or integrate with
    external systems.

    Subclasses should implement the desired hooks and override
    methods as needed.
    """

    @property
    @abstractmethod
    def IDENTIFIER(self): # noqa: N802
        """str: Identifier for extraction task (e.g. "water rights")"""
        raise NotImplementedError

    @property
    @abstractmethod
    def QUERY_TEMPLATES(self): # noqa: N802
        """list: List of search engine query templates for extraction

        Query templates can contain the placeholder ``{jurisdiction}``
        which will be replaced with the full jurisdiction name during
        the search engine query.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def WEBSITE_KEYWORDS(self): # noqa: N802
        """list: List of keywords

        List of keywords that indicate links which should be prioritized
        when performing a website scrape for a document.
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def TEXT_COLLECTORS(self): # noqa: N802
        """list of BaseTextCollector: Classes to collect text

        Should be an iterable of one or more classes to collect text
        for the extraction task. All entries must be subclasses of
        :class:`BaseTextCollector` (enforced by
        ``validate_plugin_configuration``).
        """
        raise NotImplementedError

    @property
    @abstractmethod
    def HEURISTIC(self): # noqa: N802
        """BaseHeuristic: Class with a ``check()`` method

        The ``check()`` method should accept a string of text and
        return ``True`` if the text passes the heuristic check and
        ``False`` otherwise.
        """
        raise NotImplementedError

    @classmethod
    def save_structured_data(cls, doc_infos, out_dir):
        """Write extracted structured data to disk

        Parameters
        ----------
        doc_infos : list of dict
            List of dictionaries containing the following keys:

            - "jurisdiction": An initialized Jurisdiction object
              representing the jurisdiction that was extracted.
            - "ord_db_fp": A path to the extracted structured data
              stored on disk, or ``None`` if no data was extracted.

        out_dir : path-like
            Path to the output directory for the data.

        Returns
        -------
        int
            Number of unique jurisdictions that information was
            found/written for.
        """
        db, num_docs_found = doc_infos_to_db(doc_infos)
        save_db(db, out_dir)
        return num_docs_found

    async def pre_filter_docs_hook(self, extraction_context): # noqa: PLR6301
        """Pre-process documents before running them through the filter

        The default implementation is a no-op that returns its input
        unchanged.

        .. note:: NOTE(review): despite the parameter name and the
           description below, :meth:`filter_docs` currently invokes
           this hook with ``extraction_context.documents`` (the list
           of documents), not the context object itself — TODO confirm
           the intended contract before relying on either shape.

        Parameters
        ----------
        extraction_context : ExtractionContext
            Context with downloaded documents to process.

        Returns
        -------
        ExtractionContext
            Context with documents to be passed onto the filtering step.
        """
        return extraction_context

    async def post_filter_docs_hook(self, extraction_context): # noqa: PLR6301
        """Post-process documents after running them through the filter

        The default implementation is a no-op that returns its input
        unchanged.

        .. note:: NOTE(review): despite the parameter name and the
           description below, :meth:`filter_docs` currently invokes
           this hook with the filtered list of documents, not the
           context object itself — TODO confirm the intended contract
           before relying on either shape.

        Parameters
        ----------
        extraction_context : ExtractionContext
            Context with documents that passed the filtering step.

        Returns
        -------
        ExtractionContext
            Context with documents to be passed onto the parsing step.
        """
        return extraction_context

    async def extract_relevant_text(self, doc, extractor_class, model_config):
        """Condense text for extraction task

        This method takes a text extractor and applies it to the
        collected document chunks to get a concise version of the text
        that can be used for structured data extraction.

        The extracted text will be stored in the ``.attrs`` dictionary
        of the input document under the ``extractor_class.OUT_LABEL``
        key. The input document is modified in place; nothing is
        returned. As a side effect, the cleaned text is also written to
        disk (see ``_write_cleaned_text``).

        Parameters
        ----------
        doc : BaseDocument
            Document containing text chunks to condense.
        extractor_class : BaseTextExtractor
            Class to use for text extraction.
        model_config : LLMConfig
            Configuration for the LLM model to use for text extraction.
        """
        extractor = extractor_class(
            llm_service=model_config.llm_service,
            usage_tracker=self.usage_tracker,
            **model_config.llm_call_kwargs,
        )
        doc = await extract_relevant_text_with_ngram_validation(
            doc,
            model_config.text_splitter,
            extractor,
            original_text_key=extractor_class.IN_LABEL,
        )
        await self._write_cleaned_text(doc)

    async def get_query_templates(self):
        """Get a list of search engine query templates for extraction

        Query templates can contain the placeholder ``{jurisdiction}``
        which will be replaced with the full jurisdiction name during
        the search engine query.

        Returns
        -------
        list
            The plugin's ``QUERY_TEMPLATES`` property, unmodified.
        """
        return self.QUERY_TEMPLATES

    async def get_website_keywords(self):
        """Get a dict of website search keyword scores

        Dictionary mapping keywords to scores that indicate links which
        should be prioritized when performing a website scrape for a
        document.

        Returns
        -------
        list
            The plugin's ``WEBSITE_KEYWORDS`` property, unmodified.
        """
        return self.WEBSITE_KEYWORDS

    async def get_heuristic(self):
        """Get a `BaseHeuristic` instance with a `check()` method

        The ``check()`` method should accept a string of text and return
        ``True`` if the text passes the heuristic check and ``False``
        otherwise.

        Returns
        -------
        BaseHeuristic
            A freshly-constructed instance of the plugin's
            ``HEURISTIC`` class (instantiated with no arguments).
        """
        return self.HEURISTIC()

    async def filter_docs(
        self, extraction_context, need_jurisdiction_verification=True
    ):
        """Filter down candidate documents before parsing

        Runs the pipeline: ``pre_filter_docs_hook`` ->
        ``filter_ordinance_docs`` -> ``post_filter_docs_hook``, then
        stores the surviving documents back on the context.

        Parameters
        ----------
        extraction_context : ExtractionContext
            Context containing candidate documents to be filtered.
        need_jurisdiction_verification : bool, optional
            Whether to verify that documents pertain to the correct
            jurisdiction. By default, ``True``.

        Returns
        -------
        ExtractionContext or None
            The input context with its ``documents`` replaced by the
            filtered documents, or ``None`` if no documents remain
            (or none were provided).
        """
        # Bail out early on a falsy/empty context
        if not extraction_context:
            return None
        logger.debug(
            "Passing %d document(s) in to `pre_filter_docs_hook` ",
            extraction_context.num_documents,
        )
        # NOTE(review): the hook is documented to take an
        # ExtractionContext, but the document list is passed here —
        # TODO confirm intended contract
        docs = await self.pre_filter_docs_hook(extraction_context.documents)
        logger.debug(
            "%d document(s) remaining after `pre_filter_docs_hook` for "
            "%s\n\t- %s",
            len(docs),
            self.jurisdiction.full_name,
            "\n\t- ".join(
                [doc.attrs.get("source", "Unknown source") for doc in docs]
            ),
        )
        heuristic = await self.get_heuristic()
        docs = await filter_ordinance_docs(
            docs,
            self.jurisdiction,
            self.model_configs,
            heuristic=heuristic,
            tech=self.IDENTIFIER,
            text_collectors=self.TEXT_COLLECTORS,
            usage_tracker=self.usage_tracker,
            check_for_correct_jurisdiction=need_jurisdiction_verification,
        )
        # Nothing survived the ordinance filter
        if not docs:
            return None
        logger.debug(
            "Passing %d document(s) in to `post_filter_docs_hook` ", len(docs)
        )
        docs = await self.post_filter_docs_hook(docs)
        logger.debug(
            "%d document(s) remaining after `post_filter_docs_hook` for "
            "%s\n\t- %s",
            len(docs),
            self.jurisdiction.full_name,
            "\n\t- ".join(
                [doc.attrs.get("source", "Unknown source") for doc in docs]
            ),
        )
        # The post-filter hook may also have dropped every document
        if not docs:
            return None
        extraction_context.documents = docs
        return extraction_context

    async def _write_cleaned_text(self, doc):
        """Write cleaned text to `clean_files` dir

        Delegates to ``CleanedFileWriter.call`` and records the
        resulting output path(s) on the document under the
        ``"cleaned_fps"`` attribute. Returns the (mutated) document.
        """
        out_fp = await CleanedFileWriter.call(
            doc, self.IDENTIFIER, self.jurisdiction.full_name
        )
        doc.attrs["cleaned_fps"] = out_fp
        return doc

    def validate_plugin_configuration(self):
        """[NOT PUBLIC API] Validate plugin is properly configured

        Raises ``COMPASSPluginConfigurationError`` if any required
        property is missing, empty, or invalid, then registers the
        cleaned-text output file name.
        """
        self._validate_plugin_identifier()
        self._validate_query_templates()
        self._validate_website_keywords()
        self._validate_text_collectors()
        self._register_collected_text_file_names()

    def _validate_plugin_identifier(self):
        """Validate that the plugin has a valid IDENTIFIER property"""
        try:
            # Access (and discard) the property; an abstract property
            # left unimplemented raises NotImplementedError here
            __ = self.IDENTIFIER
        except NotImplementedError:
            msg = (
                f"Plugin class {self.__class__.__name__} is missing required "
                "property 'IDENTIFIER'"
            )
            # `from None` hides the NotImplementedError traceback
            raise COMPASSPluginConfigurationError(msg) from None

    def _validate_query_templates(self):
        """Validate that the plugin has valid QUERY_TEMPLATES"""
        try:
            num_q_templates = len(self.QUERY_TEMPLATES)
        except NotImplementedError:
            msg = (
                f"Plugin class {self.__class__.__name__} is missing required "
                "property 'QUERY_TEMPLATES'"
            )
            raise COMPASSPluginConfigurationError(msg) from None
        if num_q_templates == 0:
            msg = (
                f"Plugin class {self.__class__.__name__} has an empty "
                "'QUERY_TEMPLATES' property! Please provide at least "
                "one query template."
            )
            raise COMPASSPluginConfigurationError(msg)

    def _validate_website_keywords(self):
        """Validate that the plugin has valid WEBSITE_KEYWORDS"""
        try:
            num_website_keywords = len(self.WEBSITE_KEYWORDS)
        except NotImplementedError:
            msg = (
                f"Plugin class {self.__class__.__name__} is missing required "
                "property 'WEBSITE_KEYWORDS'"
            )
            raise COMPASSPluginConfigurationError(msg) from None
        if num_website_keywords == 0:
            msg = (
                f"Plugin class {self.__class__.__name__} has an empty "
                "'WEBSITE_KEYWORDS' property! Please provide at least "
                "one website keyword."
            )
            raise COMPASSPluginConfigurationError(msg)

    def _validate_text_collectors(self):
        """Validate that the plugin has valid TEXT_COLLECTORS"""
        try:
            collectors = self.TEXT_COLLECTORS
        except NotImplementedError:
            msg = (
                f"Plugin class {self.__class__.__name__} is missing required "
                "property 'TEXT_COLLECTORS'"
            )
            raise COMPASSPluginConfigurationError(msg) from None
        if len(collectors) == 0:
            msg = (
                f"Plugin class {self.__class__.__name__} has an empty "
                "'TEXT_COLLECTORS' property! Please provide at least "
                "one text collector class."
            )
            raise COMPASSPluginConfigurationError(msg)
        # Every entry must be a BaseTextCollector subclass so the
        # filtering pipeline can rely on OUT_LABEL / check_chunk
        for collector_class in collectors:
            if not issubclass(collector_class, BaseTextCollector):
                msg = (
                    f"Plugin class {self.__class__.__name__} has invalid "
                    "entry in 'TEXT_COLLECTORS' property: All entries must "
                    "be subclasses of "
                    "compass.plugin.interface.BaseTextCollector, but "
                    f"{collector_class.__name__} is not!"
                )
                raise COMPASSPluginConfigurationError(msg)

    def _register_collected_text_file_names(self):
        """Register file names for writing cleaned text outputs"""
        CLEANED_FP_REGISTRY.setdefault(self.IDENTIFIER.casefold(), {})
        # Only the LAST collector's OUT_LABEL is registered —
        # presumably the final collector produces the text that gets
        # persisted; TODO confirm this is intentional for multi-collector
        # plugins
        collected_text_key = list(self.TEXT_COLLECTORS)[-1].OUT_LABEL
        CLEANED_FP_REGISTRY[self.IDENTIFIER.casefold()][collected_text_key] = (
            "{jurisdiction} Collected Text.txt"
        )