"""Ordinance document content Validation logic
These are primarily used to validate that a legal document applies to a
particular technology (e.g. Large Wind Energy Conversion Systems).
"""
import asyncio
import logging
from abc import ABC, abstractmethod
from warnings import warn
from compass.llm.calling import ChatLLMCaller, StructuredLLMCaller
from compass.validation.graphs import setup_graph_correct_document_type
from compass.common import setup_async_decision_tree, run_async_tree
from compass.utilities.enums import LLMUsageCategory
from compass.utilities.ngrams import convert_text_to_sentence_ngrams
from compass.warn import COMPASSWarning
logger = logging.getLogger(__name__)
[docs]
class ParseChunksWithMemory:
    """Walk ordered text chunks while remembering prior LLM verdicts

    One verdict dictionary is kept per text chunk. When a chunk is
    validated, the verdicts cached for that chunk and for the chunks
    immediately before it are consulted first, and the LLM callback is
    only awaited for chunks that have no cached verdict stored under
    the requested key.
    """

    def __init__(self, text_chunks, num_to_recall=2):
        """

        Parameters
        ----------
        text_chunks : list of str
            Chunks of text in document order. Earlier chunks may be
            revisited when answering a validation question about a
            later chunk.
        num_to_recall : int, optional
            Total number of chunks examined per validation call,
            counting the requested chunk itself. For example,
            ``num_to_recall=2`` examines the chunk at the requested
            index and then the chunk right before it.
            By default, ``2``.
        """
        self.text_chunks = text_chunks
        self.num_to_recall = num_to_recall
        # one verdict cache per chunk, keyed by validation key
        self.memory = [{} for _ in text_chunks]

    def _inverted_mem(self, starting_ind):
        """Yield verdict caches from `starting_ind` backwards"""
        newest_first = self.memory[: starting_ind + 1][::-1]
        for cached in newest_first[: self.num_to_recall]:
            yield cached

    def _inverted_text(self, starting_ind):
        """Yield text chunks from `starting_ind` backwards"""
        newest_first = self.text_chunks[: starting_ind + 1][::-1]
        for chunk in newest_first[: self.num_to_recall]:
            yield chunk

    async def parse_from_ind(self, ind, key, llm_call_callback):
        """Validate a chunk, falling back on the chunks preceding it

        The chunk at `ind` is examined first, followed by earlier
        chunks (up to ``num_to_recall`` chunks in total). For each one,
        a verdict already cached under `key` is reused; otherwise the
        callback is awaited and its result is cached. The search stops
        at the first truthy verdict.

        Parameters
        ----------
        ind : int
            Index of the chunk to inspect. Must be less than the number
            of available chunks.
        key : str
            JSON key expected in the LLM response. The same key is used
            to store the verdict in the per-chunk cache.
        llm_call_callback : callable
            Awaitable invoked with ``(key, text_chunk)`` that returns a
            boolean indicating whether the chunk satisfies the LLM
            validation check.

        Returns
        -------
        bool
            The first truthy verdict found among the examined chunks,
            or ``False`` if none of them satisfies the check.
        """
        logger.debug("Checking %r for ind %d", key, ind)
        recalled = zip(
            self._inverted_mem(ind), self._inverted_text(ind), strict=False
        )
        for step, (mem, text) in enumerate(recalled):
            logger.debug("Mem at ind %d is %s", step, mem)
            verdict = mem.get(key)
            if verdict is None:
                # nothing cached for this chunk yet: ask and remember
                verdict = await llm_call_callback(key, text)
                mem[key] = verdict
            if verdict:
                return verdict
        return False
[docs]
class Heuristic(ABC):
    """Cheap keyword-based screen for mentions of a technology in text"""

    _GOOD_ACRONYM_CONTEXTS = [
        " {acronym} ",
        " {acronym}\n",
        " {acronym}.",
        "\n{acronym} ",
        "\n{acronym}.",
        "\n{acronym}\n",
        "({acronym} ",
        " {acronym})",
    ]

    def check(self, text, match_count_threshold=1):
        """Check whether text appears to mention the technology

        The text is first case-folded and stripped of tech
        "look-alike" words (e.g. "window", "windshield", etc for "wind"
        technology). The remaining text is then searched for keywords,
        acronyms, and phrases that pertain to the technology, and the
        three match counts are added together.

        Parameters
        ----------
        text : str
            Input text that may or may not mention the technology of
            interest.
        match_count_threshold : int, optional
            Number of matches that the text must exceed to pass this
            heuristic check (the comparison is strictly greater-than).
            By default, ``1``.

        Returns
        -------
        bool
            ``True`` if the number of keywords/acronyms/phrases detected
            exceeds the `match_count_threshold`.
        """
        prepared = self._convert_to_heuristics_text(text)
        counters = (
            self._count_single_keyword_matches,
            self._count_acronym_matches,
            self._count_phrase_matches,
        )
        num_matches = sum(count(prepared) for count in counters)
        return num_matches > match_count_threshold

    def _convert_to_heuristics_text(self, text):
        """Case-fold text and remove every look-alike word from it"""
        cleaned = text.casefold()
        for lookalike in self.NOT_TECH_WORDS:
            cleaned = cleaned.replace(lookalike, "")
        return cleaned

    def _count_single_keyword_matches(self, heuristics_text):
        """Count how many good tech keywords occur in the text"""
        num_found = 0
        for keyword in self.GOOD_TECH_KEYWORDS:
            if keyword in heuristics_text:
                num_found += 1
        return num_found

    def _count_acronym_matches(self, heuristics_text):
        """Count good tech acronyms found in an accepted context

        Contexts are tried in the order they are listed. The count
        comes from the first context in which at least one acronym is
        found; later contexts are not consulted after that.
        """
        for context in self._GOOD_ACRONYM_CONTEXTS:
            candidates = {
                context.format(acronym=acronym)
                for acronym in self.GOOD_TECH_ACRONYMS
            }
            num_found = sum(
                1 for candidate in candidates if candidate in heuristics_text
            )
            if num_found > 0:
                return num_found
        return 0

    def _count_phrase_matches(self, heuristics_text):
        """Count how many good tech phrases occur in the text

        Each phrase is compared (as-is and with an "s" appended) against
        the sentence n-grams of the text. Single-word phrases trigger a
        warning and are skipped.
        """
        ngrams_by_size = {}  # text n-grams, computed once per phrase length
        num_found = 0
        for phrase in self.GOOD_TECH_PHRASES:
            num_words = len(phrase.split(" "))
            if num_words <= 1:
                msg = (
                    "Make sure your GOOD_TECH_PHRASES contain at least 2 "
                    f"words! Got phrase: {phrase!r}"
                )
                warn(msg, COMPASSWarning)
                continue

            known_ngrams = ngrams_by_size.get(num_words)
            if known_ngrams is None:
                known_ngrams = ngrams_by_size[num_words] = set(
                    convert_text_to_sentence_ngrams(heuristics_text, num_words)
                )

            singular = convert_text_to_sentence_ngrams(phrase, num_words)
            plural = convert_text_to_sentence_ngrams(f"{phrase}s", num_words)
            if any(ngram in known_ngrams for ngram in singular + plural):
                num_found += 1

        return num_found

    @property
    @abstractmethod
    def NOT_TECH_WORDS(self):  # noqa: N802
        """:class:`~collections.abc.Iterable`: Not tech keywords"""
        raise NotImplementedError

    @property
    @abstractmethod
    def GOOD_TECH_KEYWORDS(self):  # noqa: N802
        """:class:`~collections.abc.Iterable`: Tech keywords"""
        raise NotImplementedError

    @property
    @abstractmethod
    def GOOD_TECH_ACRONYMS(self):  # noqa: N802
        """:class:`~collections.abc.Iterable`: Tech acronyms"""
        raise NotImplementedError

    @property
    @abstractmethod
    def GOOD_TECH_PHRASES(self):  # noqa: N802
        """:class:`~collections.abc.Iterable`: Tech phrases"""
        raise NotImplementedError
[docs]
class LegalTextValidator(StructuredLLMCaller):
    """Decide chunk-by-chunk whether text comes from a legal document"""

    SYSTEM_MESSAGE = (
        "You are an AI designed to classify text excerpts based on their "
        "source type. The goal is to identify text that is extracted from "
        "**legally binding regulations (such as zoning ordinances or "
        "enforceable bans)** and filter out text that was extracted from "
        "anything other than a legal statute for an existing jurisdiction."
    )
    """System message for legal text validation LLM calls"""

    def __init__(
        self,
        tech,
        *args,
        score_threshold=0.8,
        doc_is_from_ocr=False,
        **kwargs,
    ):
        """

        Parameters
        ----------
        tech : str
            Technology of interest (e.g. "solar", "wind", etc). This is
            used to set up some document validation decision trees.
        score_threshold : float, optional
            Minimum fraction of checked text chunks that have to pass
            the legal check for the whole document to be considered
            legal text. By default, ``0.8``.
        doc_is_from_ocr : bool, optional
            Flag forwarded unchanged to the document-type decision tree
            setup (presumably marks text obtained via OCR; confirm
            against the graph setup function). By default, ``False``.
        *args, **kwargs
            Parameters to pass to the StructuredLLMCaller initializer.
        """
        super().__init__(*args, **kwargs)
        self.tech = tech
        self.doc_is_from_ocr = doc_is_from_ocr
        self.score_threshold = score_threshold
        # one verdict per `check_chunk` call, in call order
        self._legal_text_mem = []

    @property
    def is_legal_text(self):
        """bool: ``True`` if text was found to be from a legal source"""
        verdicts = self._legal_text_mem
        if not verdicts:
            # nothing has been checked yet, so nothing is confirmed
            return False
        return sum(verdicts) / len(verdicts) >= self.score_threshold

    async def check_chunk(self, chunk_parser, ind):
        """Check whether the chunk at a given index contains legal text

        The verdict is also recorded on this instance so that it counts
        towards :attr:`is_legal_text`.

        Parameters
        ----------
        chunk_parser : ParseChunksWithMemory
            Instance that contains a ``parse_from_ind`` method.
        ind : int
            Index of the chunk to check.

        Returns
        -------
        bool
            Boolean flag indicating whether or not the text in the chunk
            resembles legal text.
        """
        verdict = await chunk_parser.parse_from_ind(
            ind,
            key="legal_text",
            llm_call_callback=self._check_chunk_for_legal_text,
        )
        self._legal_text_mem.append(verdict)
        msg = (
            "Text at ind %d is legal text"
            if verdict
            else "Text at ind %d is not legal text"
        )
        logger.debug(msg, ind)
        return verdict

    async def _check_chunk_for_legal_text(self, key, text_chunk):
        """Run the document-type decision tree on one chunk of text"""
        system_message = self.SYSTEM_MESSAGE.format(key=key)
        chat_llm_caller = ChatLLMCaller(
            llm_service=self.llm_service,
            system_message=system_message,
            usage_tracker=self.usage_tracker,
            **self.kwargs,
        )
        tree = setup_async_decision_tree(
            setup_graph_correct_document_type,
            usage_sub_label=LLMUsageCategory.DOCUMENT_CONTENT_VALIDATION,
            tech=self.tech,
            key=key,
            text=text_chunk,
            chat_llm_caller=chat_llm_caller,
            doc_is_from_ocr=self.doc_is_from_ocr,
        )
        response = await run_async_tree(tree, response_as_json=True)
        logger.debug("LLM response: %s", response)
        # a response without the key counts as "not legal text"
        return response.get(key, False)
[docs]
async def parse_by_chunks(
    chunk_parser,
    heuristic,
    legal_text_validator=None,
    callbacks=None,
    min_chunks_to_process=3,
):
    """Stream text chunks through heuristic and legal validators

    Chunks are visited in order and the heuristic is evaluated on every
    visited chunk. The first `min_chunks_to_process` chunks are run
    through the `legal_text_validator` (when one is given) and, if they
    pass, handed to the callbacks regardless of the heuristic outcome.
    For every later chunk, parsing is aborted if the validator no longer
    considers the document legal text; otherwise the chunk is handed to
    the callbacks only if it, or one of the chunks just before it
    (``num_to_recall`` chunks in total), passed the heuristic.

    Parameters
    ----------
    chunk_parser : ParseChunksWithMemory
        Instance that contains the attributes ``text_chunks`` and
        ``num_to_recall``. The chunks in the ``text_chunks`` attribute
        will be iterated over.
    heuristic : Heuristic
        Instance of `Heuristic` with a `check` method. This should be a
        fast check meant to quickly dispose of chunks of text. It gates
        the callbacks for all chunks after the first
        `min_chunks_to_process`.
    legal_text_validator : LegalTextValidator, optional
        Instance of `LegalTextValidator` that can be used to validate
        each chunk for legal text. If not provided, the legal text check
        will be skipped. By default, ``None``.
    callbacks : list, optional
        List of async callbacks that take a `chunk_parser` and `index`
        as inputs and return a boolean determining whether the text
        chunk was parsed successfully or not. By default, ``None``,
        which does not use any callbacks.
    min_chunks_to_process : int, optional
        Minimum number of chunks to process before aborting due to text
        not being legal. By default, ``3``.

    Notes
    -----
    This coroutine only orchestrates validation and returns nothing.
    Callbacks are responsible for persisting any extracted results.
    Callback tasks for a chunk are awaited concurrently and are given
    the same task name as the caller to simplify tracing within
    structured logging.
    """
    callbacks = callbacks or []
    heuristic_hits = []
    validate_legal = legal_text_validator is not None
    task_name = asyncio.current_task().get_name()

    for ind, text in enumerate(chunk_parser.text_chunks):
        heuristic_hits.append(heuristic.check(text))

        if ind < min_chunks_to_process:
            # leading chunks are only gated by the legal text check
            if validate_legal:
                is_legal = await legal_text_validator.check_chunk(
                    chunk_parser, ind
                )
                if not is_legal:  # don't bother checking this chunk
                    continue
        else:
            # don't bother checking this document
            if validate_legal and not legal_text_validator.is_legal_text:
                return
            # hasn't passed heuristic, so don't pass it to callbacks
            if not any(heuristic_hits[-chunk_parser.num_to_recall :]):
                continue

        logger.debug("Processing text at ind %d", ind)
        logger.debug_to_file("Text:\n%s", text)

        if callbacks:
            cb_tasks = [
                asyncio.create_task(cb(chunk_parser, ind), name=task_name)
                for cb in callbacks
            ]
            cb_results = await asyncio.gather(*cb_tasks)

            # mask this chunk if we got a good result - this avoids
            # forcing the following chunk to be checked (it will only be
            # checked if it itself passes the heuristic)
            heuristic_hits[-1] = not any(cb_results)