# Source code for compass.plugin.one_shot.base
"""COMPASS one-shot extraction plugin"""
import logging
import importlib.resources
from asyncio import Semaphore
from enum import StrEnum, auto
from compass.llm.calling import SchemaOutputLLMCaller
from compass.plugin import (
register_plugin,
NoOpHeuristic,
NoOpTextCollector,
NoOpTextExtractor,
PromptBasedTextCollector,
PromptBasedTextExtractor,
OrdinanceExtractionPlugin,
KeywordBasedHeuristic,
)
from compass.plugin.one_shot.generators import (
generate_query_templates,
generate_website_keywords,
generate_heuristic_keywords,
)
from compass.plugin.one_shot.components import (
SchemaBasedTextCollector,
SchemaBasedTextExtractor,
SchemaOrdinanceParser,
)
from compass.plugin.one_shot.cache import key_from_cache, key_to_cache
from compass.services.threaded import CLEANED_FP_REGISTRY
from compass.utilities.io import load_config
from compass.utilities.enums import LLMTasks
from compass.exceptions import COMPASSPluginConfigurationError
logger = logging.getLogger(__name__)
# Traversable directory containing the bundled JSON5 output schemas used by
# the schema-based text collectors/extractors (e.g. validate_chunk.json5)
_SCHEMA_DIR = importlib.resources.files("compass.plugin.one_shot.schemas")
# One semaphore per LLM-generated artifact so concurrent plugin instances
# trigger at most one generation call each for query templates (QT),
# website keywords (WK), and heuristic keywords (HK)
_QT_SEMAPHORE = Semaphore(1)
_WK_SEMAPHORE = Semaphore(1)
_HK_SEMAPHORE = Semaphore(1)
class _CacheKey(StrEnum):
"""LLM generated content cache keys"""
QUERY_TEMPLATES = auto()
WEBSITE_KEYWORDS = auto()
HEURISTIC_KEYWORDS = auto()
# [docs]
def create_schema_based_one_shot_extraction_plugin(config, tech):  # noqa: C901
    """Create a one-shot extraction plugin based on a configuration

    Parameters
    ----------
    config : dict or path-like
        One-shot configuration dictionary. If not a dictionary, should
        be a path to a file containing the configuration (supported
        formats: JSON, JSON5, YAML, TOML). See the
        `wind ordinance schema <https://github.com/NatLabRockies/COMPASS/blob/main/examples/one_shot_schema_extraction/wind_schema.json>`_
        for an example. The configuration must include the following
        keys:

        - `schema`: A dictionary representing the schema of the
          output. Can also be a path to a file that contains the
          schema (supported formats: JSON, JSON5, YAML, TOML). See
          the wind ordinance schema for an example.

        The configuration can also include the following optional keys:

        - `data_type_short_desc`: Short description of the type of
          data being extracted with this plugin, in the format
          `wind energy ordinance`, `solar energy ordinance`,
          `water rights`. This is used to enhance the prompts for
          the structured data extraction.
        - `query_templates`: A list of search engine query
          templates for document retrieval. Templates should include
          ``{jurisdiction}`` as a placeholder for the jurisdiction
          that is being processed. If not provided, the LLM will be
          used to generate search engine queries based on the
          schema input.
        - `website_keywords`: A dictionary mapping keywords to
          scores for filtering websites during document retrieval.
          If not provided, the LLM will be used to generate
          website keywords based on the schema input.
        - `heuristic_keywords`: A dictionary containing the keyword
          lists used by the heuristic document filter. The
          dictionary must include ``not_tech_words``,
          ``good_tech_keywords``, ``good_tech_acronyms``, and
          ``good_tech_phrases`` keys. Alternatively, this input can
          simply be ``True``, in which case the LLM will be used to
          generate heuristic keyword lists based on the schema
          input. If ``False``, ``None``, or not provided, a `NoOp`
          heuristic that always returns ``True`` will be used (not
          recommended if doing website crawling).
        - `collection_prompts`: A list of prompts to use for
          collecting relevant text from documents. Alternatively,
          this input can simply be ``True``, in which case the LLM
          will be used to generate the collection prompts. If
          ``False``, ``None``, or not provided, the entire document
          text will be used for extraction (no text collection).
        - `text_extraction_prompts`: A list of prompts to use for
          consolidating and extracting relevant text from the
          documents. Alternatively, this input can simply be
          ``True``, in which case the LLM will be used to generate
          the text extraction prompts. If ``False``, ``None``, or
          not provided, the entire document text will be used for
          extraction (no text consolidation).
        - `cache_llm_generated_content`: Boolean flag indicating
          whether or not to cache generated query templates and
          website keywords for future use. By default, ``True``.
          Caching is recommended since the generation of query
          templates and website keywords can be costly, but if you
          are iterating on the configuration and want to see the
          effect of changes to the schema on the generated query
          templates and website keywords in real time, you may want
          to set this flag to ``False`` to avoid caching generated
          templates/keywords until you have finalized the schema.
        - `extraction_system_prompt`: Custom system prompt to use
          for the structured data extraction step. If not provided,
          a default prompt will be used that instructs the LLM to
          extract structured data from the given document(s). You
          may provide a custom system prompt if you want to provide
          more specific instructions to the LLM for the structured
          data extraction step.
        - `allow_multi_doc_extraction`: Boolean flag indicating
          whether to allow multiple documents to be used for the
          extraction context simultaneously. By default, ``False``,
          which means the first document that returns some extracted
          data will be marked as the source.
    tech : str
        Technology identifier to use for the plugin (e.g., "wind",
        "solar"). Must be unique from the identifiers of any existing
        plugins.
    """
    if not isinstance(config, dict):
        config = load_config(config)
    if isinstance(config["schema"], str):
        # schema was given as a path - load it into a dictionary
        config["schema"] = load_config(config["schema"])
    # NOTE: this pops "$qualitative_features" out of the schema, so the
    # schema dict passed to all downstream components no longer has it;
    # feature names are casefolded for case-insensitive matching
    config["qual_feats"] = {
        f.casefold() for f in config["schema"].pop("$qualitative_features", [])
    }
    # Build the processing pipeline; each stage consumes the previous
    # stage's output label
    text_collectors = _collectors_from_config(config)
    text_extractors = _extractors_from_config(
        config, in_label=text_collectors[-1].OUT_LABEL, tech=tech
    )
    parsers = _parser_from_config(
        config, in_label=text_extractors[-1].OUT_LABEL
    )

    # The plugin class closes over ``config``; query templates, website
    # keywords, and heuristic keywords are resolved lazily and cached as
    # class attributes so all instances share them
    class SchemaBasedExtractionPlugin(OrdinanceExtractionPlugin):
        SCHEMA = config["schema"]
        """dict: Schema for the output of the text extraction step"""

        ALLOW_MULTI_DOC_EXTRACTION = config.get(
            "allow_multi_doc_extraction", False
        )
        """bool: Whether to allow extraction over multiple documents"""

        IDENTIFIER = tech
        """str: Identifier for extraction task """

        HEURISTIC = NoOpHeuristic
        """BaseHeuristic: Class with a ``check()`` method"""

        HEURISTIC_KEYWORDS = None
        """dict: Keyword lists for heuristic content filtering"""

        TEXT_COLLECTORS = text_collectors
        """Classes for collecting text chunks from docs"""

        TEXT_EXTRACTORS = text_extractors
        """Classes for extracting cleaned text from collected text"""

        PARSERS = parsers
        """Classes for parsing structured ordinance data from text"""

        QUERY_TEMPLATES = []  # set by user or LLM-generated
        """list: List of search engine query templates"""

        WEBSITE_KEYWORDS = {}  # set by user or LLM-generated
        """dict: Keyword weight mapping for link crawl prioritization"""

        async def get_heuristic(self):
            """Get a `BaseHeuristic` instance with a `check()` method

            The ``check()`` method should accept a string of text and
            return ``True`` if the text passes the heuristic check and
            ``False`` otherwise.
            """
            # fast path: keyword lists already resolved and a concrete
            # heuristic class was installed on this plugin class
            if self.HEURISTIC_KEYWORDS and self.HEURISTIC is not NoOpHeuristic:
                return self.HEURISTIC()
            if not config.get("heuristic_keywords"):
                return NoOpHeuristic()
            hk = await self._get_heuristic_keywords()

            class SchemaBasedHeuristic(KeywordBasedHeuristic):
                NOT_TECH_WORDS = hk["NOT_TECH_WORDS"]
                GOOD_TECH_KEYWORDS = hk["GOOD_TECH_KEYWORDS"]
                GOOD_TECH_ACRONYMS = hk["GOOD_TECH_ACRONYMS"]
                GOOD_TECH_PHRASES = hk["GOOD_TECH_PHRASES"]

            # cache on the class so all instances share the result
            self.__class__.HEURISTIC_KEYWORDS = hk
            self.__class__.HEURISTIC = SchemaBasedHeuristic
            return self.HEURISTIC()

        async def get_query_templates(self):
            """Get a list of query templates for document retrieval

            Returns
            -------
            list
                List of search engine query templates for document
                retrieval. Templates may include ``{jurisdiction}`` as
                a placeholder for the jurisdiction that is being
                processed.
            """
            # resolution order: class cache -> user config -> disk
            # cache -> LLM generation (guarded by a semaphore)
            if self.QUERY_TEMPLATES:
                return self.QUERY_TEMPLATES
            if qt := config.get("query_templates"):
                self.__class__.QUERY_TEMPLATES = qt
                return qt
            # NOTE(review): disk cache is read even when
            # cache_llm_generated_content is False - confirm intended
            qt = key_from_cache(
                self.IDENTIFIER,
                config["schema"],
                key=_CacheKey.QUERY_TEMPLATES,
            )
            if qt:
                self.__class__.QUERY_TEMPLATES = qt
                return qt
            async with _QT_SEMAPHORE:
                # double-check: another coroutine may have generated the
                # templates while we were waiting on the semaphore
                if self.QUERY_TEMPLATES:
                    return self.QUERY_TEMPLATES
                model_config = self.model_configs.get(
                    LLMTasks.PLUGIN_GENERATION,
                    self.model_configs[LLMTasks.DEFAULT],
                )
                schema_llm = SchemaOutputLLMCaller(
                    llm_service=model_config.llm_service,
                    usage_tracker=self.usage_tracker,
                    **model_config.llm_call_kwargs,
                )
                logger.debug("Generating query templates...")
                qt = await generate_query_templates(
                    schema_llm, config["schema"], add_think_prompt=True
                )
                logger.debug(
                    "Generated the following query templates:\n%r", qt
                )
                self.__class__.QUERY_TEMPLATES = qt
                if config.get("cache_llm_generated_content", True):
                    key_to_cache(
                        self.IDENTIFIER,
                        config["schema"],
                        key=_CacheKey.QUERY_TEMPLATES,
                        value=qt,
                    )
                return qt

        async def get_website_keywords(self):
            """Get a dict of website search keyword scores

            Returns
            -------
            dict
                Dictionary mapping keywords to scores that indicate
                links which should be prioritized when performing a
                website scrape for a document.
            """
            # resolution order mirrors get_query_templates; keywords
            # are augmented with URL-encoded variants after retrieval
            if self.WEBSITE_KEYWORDS:
                return self.WEBSITE_KEYWORDS
            if wk := config.get("website_keywords"):
                wk = _augment_website_keywords(wk)
                self.__class__.WEBSITE_KEYWORDS = wk
                return wk
            wk = key_from_cache(
                self.IDENTIFIER,
                config["schema"],
                key=_CacheKey.WEBSITE_KEYWORDS,
            )
            if wk:
                wk = _augment_website_keywords(wk)
                self.__class__.WEBSITE_KEYWORDS = wk
                return wk
            async with _WK_SEMAPHORE:
                # double-check after acquiring the semaphore
                if self.WEBSITE_KEYWORDS:
                    return self.WEBSITE_KEYWORDS
                model_config = self.model_configs.get(
                    LLMTasks.PLUGIN_GENERATION,
                    self.model_configs[LLMTasks.DEFAULT],
                )
                schema_llm = SchemaOutputLLMCaller(
                    llm_service=model_config.llm_service,
                    usage_tracker=self.usage_tracker,
                    **model_config.llm_call_kwargs,
                )
                logger.debug("Generating website keywords...")
                wk = await generate_website_keywords(
                    schema_llm,
                    config["schema"],
                    add_think_prompt=True,
                )
                logger.debug(
                    "Generated the following website keywords:\n%r", wk
                )
                # cache the raw LLM output; URL-encoded variants are
                # re-derived on load by _augment_website_keywords
                if config.get("cache_llm_generated_content", True):
                    key_to_cache(
                        self.IDENTIFIER,
                        config["schema"],
                        key=_CacheKey.WEBSITE_KEYWORDS,
                        value=wk,
                    )
                wk = _augment_website_keywords(wk)
                self.__class__.WEBSITE_KEYWORDS = wk
                return wk

        async def _get_heuristic_keywords(self):
            """Get keyword lists for the heuristic document filter"""
            # resolution order mirrors get_query_templates; user input
            # is only honored here when it is a dict (``True`` means
            # "generate with the LLM")
            if self.HEURISTIC_KEYWORDS:
                return self.HEURISTIC_KEYWORDS
            if isinstance(hk := config.get("heuristic_keywords"), dict):
                hk = _normalize_heuristic_keywords(hk)
                self.__class__.HEURISTIC_KEYWORDS = hk
                return hk
            hk = key_from_cache(
                self.IDENTIFIER,
                config["schema"],
                key=_CacheKey.HEURISTIC_KEYWORDS,
            )
            if hk:
                hk = _normalize_heuristic_keywords(hk)
                self.__class__.HEURISTIC_KEYWORDS = hk
                return hk
            async with _HK_SEMAPHORE:
                # double-check after acquiring the semaphore
                if self.HEURISTIC_KEYWORDS:
                    return self.HEURISTIC_KEYWORDS
                model_config = self.model_configs.get(
                    LLMTasks.PLUGIN_GENERATION,
                    self.model_configs[LLMTasks.DEFAULT],
                )
                schema_llm = SchemaOutputLLMCaller(
                    llm_service=model_config.llm_service,
                    usage_tracker=self.usage_tracker,
                    **model_config.llm_call_kwargs,
                )
                logger.debug("Generating heuristic keywords...")
                hk = await generate_heuristic_keywords(
                    schema_llm,
                    config["schema"],
                    add_think_prompt=True,
                )
                hk = _normalize_heuristic_keywords(hk)
                logger.debug(
                    "Generated the following heuristic keywords:\n%r", hk
                )
                if config.get("cache_llm_generated_content", True):
                    key_to_cache(
                        self.IDENTIFIER,
                        config["schema"],
                        key=_CacheKey.HEURISTIC_KEYWORDS,
                        value=hk,
                    )
                self.__class__.HEURISTIC_KEYWORDS = hk
                return hk

        def _validate_query_templates(self):
            """NoOp validation for query templates

            Since templates can be generated by LLM, we don't know until
            runtime whether or not they will be valid.
            """

        def _validate_website_keywords(self):
            """NoOp validation for website keywords

            Since keywords can be generated by LLM, we don't know until
            runtime whether or not they will be valid.
            """

    register_plugin(SchemaBasedExtractionPlugin)
def _collectors_from_config(config):
    """Build the list of text-collector classes for a plugin config.

    ``collection_prompts`` semantics: ``True`` selects the schema-driven
    collector (with the bundled ``validate_chunk.json5`` output schema);
    any other truthy value is treated as a list of prompts for a
    prompt-based collector; falsy values yield the no-op collector.
    """
    prompts = config.get("collection_prompts")

    if prompts is True:

        class PluginTextCollector(SchemaBasedTextCollector):
            OUT_LABEL = NoOpTextCollector.OUT_LABEL  # reuse label
            SCHEMA = config["schema"]
            OUTPUT_SCHEMA = load_config(_SCHEMA_DIR / "validate_chunk.json5")

    elif prompts:

        class PluginTextCollector(PromptBasedTextCollector):
            OUT_LABEL = NoOpTextCollector.OUT_LABEL  # reuse label
            PROMPTS = prompts

    else:
        return [NoOpTextCollector]

    return [PluginTextCollector]
def _extractors_from_config(config, in_label, tech):
    """Build the list of text-extractor classes for a plugin config.

    ``text_extraction_prompts`` semantics mirror the collector config:
    ``True`` selects the schema-driven extractor (with the bundled
    ``extract_text.json5`` output schema); any other truthy value is
    treated as a list of prompts; falsy values yield a no-op extractor
    that passes the collected text straight through.
    """
    prompts = config.get("text_extraction_prompts")

    if prompts is True:

        class PluginTextExtractor(SchemaBasedTextExtractor):
            IN_LABEL = in_label
            OUT_LABEL = "copied_relevant_text"
            SCHEMA = config["schema"]
            OUTPUT_SCHEMA = load_config(_SCHEMA_DIR / "extract_text.json5")

        # register the on-disk filename used when persisting this
        # extractor's cleaned text for the given technology
        CLEANED_FP_REGISTRY.setdefault(tech.casefold(), {})[
            "copied_relevant_text"
        ] = "Text for Extraction.txt"
    elif prompts:

        class PluginTextExtractor(PromptBasedTextExtractor):
            IN_LABEL = in_label
            PROMPTS = prompts

    else:

        class PluginTextExtractor(NoOpTextExtractor):
            IN_LABEL = in_label
            OUT_LABEL = "copied_relevant_text"

    return [PluginTextExtractor]
def _parser_from_config(config, in_label):
    """Build the structured-data parser class for a plugin config.

    Subclasses ``SchemaOrdinanceParser`` with the user's schema,
    qualitative features, optional short description, and optional
    custom system prompt (falling back to the parser default).
    """
    system_prompt = config.get(
        "extraction_system_prompt", SchemaOrdinanceParser.SYSTEM_PROMPT
    )

    class PluginParser(SchemaOrdinanceParser):
        IN_LABEL = in_label
        OUT_LABEL = "structured_data"
        SCHEMA = config["schema"]
        QUALITATIVE_FEATURES = config["qual_feats"]
        DATA_TYPE_SHORT_DESC = config.get("data_type_short_desc")
        SYSTEM_PROMPT = system_prompt

    return [PluginParser]
def _augment_website_keywords(keywords):
"""Add URL-encoded variants for multi-word keywords"""
augmented = dict(keywords)
for keyword, score in list(augmented.items()):
if not isinstance(keyword, str):
continue
if " " not in keyword:
continue
encoded = keyword.replace(" ", "%20")
if encoded not in augmented:
augmented[encoded] = score
plus_encoded = keyword.replace(" ", "+")
if plus_encoded not in augmented:
augmented[plus_encoded] = score
return augmented
def _normalize_heuristic_keywords(raw):
    """Validate and canonicalize user-supplied heuristic keyword lists.

    Keys are matched against the four required list names after
    stripping whitespace, converting spaces/hyphens to underscores, and
    upper-casing; values are normalized via ``_normalize_keyword_list``.

    Raises ``COMPASSPluginConfigurationError`` if the input is not a
    dict, contains non-string or unknown keys, is missing a required
    list, or if any list normalizes to empty.
    """
    if not isinstance(raw, dict):
        msg = "Heuristic keywords must be a dictionary of keyword lists."
        raise COMPASSPluginConfigurationError(msg)

    expected_keys = {
        "NOT_TECH_WORDS",
        "GOOD_TECH_KEYWORDS",
        "GOOD_TECH_ACRONYMS",
        "GOOD_TECH_PHRASES",
    }
    normalized = {}
    for user_key, keyword_list in raw.items():
        if not isinstance(user_key, str):
            msg = "Heuristic keyword keys must be strings."
            raise COMPASSPluginConfigurationError(msg)
        canonical = (
            user_key.strip().replace(" ", "_").replace("-", "_").upper()
        )
        if canonical not in expected_keys:
            msg = f"Unexpected heuristic keyword list: {user_key!r}."
            raise COMPASSPluginConfigurationError(msg)
        normalized[canonical] = _normalize_keyword_list(keyword_list)

    if missing := expected_keys - set(normalized):
        msg = (
            f"Heuristic keywords are missing required lists: {sorted(missing)}"
        )
        raise COMPASSPluginConfigurationError(msg)
    if empty := sorted(k for k, v in normalized.items() if not v):
        msg = f"Heuristic keyword lists must not be empty: {empty}"
        raise COMPASSPluginConfigurationError(msg)
    return normalized
def _normalize_keyword_list(items):
"""Normalize keyword list entries"""
normalized = set()
for item in items:
if not isinstance(item, str):
continue
keyword = item.strip()
if not keyword:
continue
keyword = keyword.casefold()
if keyword in normalized:
continue
normalized.add(keyword)
return list(normalized)