"""Ordinance document content Validation logic
These are primarily used to validate that a legal document applies to a
particular technology (e.g. Large Wind Energy Conversion Systems).
"""
import asyncio
import logging
from abc import ABC, abstractmethod
from warnings import warn
from compass.llm.calling import ChatLLMCaller, JSONFromTextLLMCaller
from compass.validation.graphs import setup_graph_correct_document_type
from compass.common import setup_async_decision_tree, run_async_tree
from compass.utilities.enums import LLMUsageCategory
from compass.utilities.ngrams import convert_text_to_sentence_ngrams
from compass.warn import COMPASSWarning
logger = logging.getLogger(__name__)
[docs]
class ParseChunksWithMemory:
"""Iterate through text chunks while caching prior LLM decisions
This helper stores an in-memory cache of prior validation results so
each chunk can optionally reuse outcomes from earlier LLM calls. The
design supports revisiting a configurable number of preceding text
chunks when newer chunks lack sufficient context.
"""
def __init__(self, text_chunks, num_to_recall=2):
"""
Parameters
----------
text_chunks : list of str
List of strings, each of which represent a chunk of text.
The order of the strings should be the order of the text
chunks. This validator may refer to previous text chunks to
answer validation questions.
num_to_recall : int, optional
Number of chunks to check for each validation call. This
includes the original chunk! For example, if
`num_to_recall=2`, the validator will first check the chunk
at the requested index, and then the previous chunk as well.
By default, ``2``.
"""
self.text_chunks = text_chunks
self.num_to_recall = num_to_recall
self.memory = [{} for _ in text_chunks]
def _inverted_mem(self, starting_ind):
"""Inverted memory"""
inverted_mem = self.memory[:starting_ind + 1:][::-1] # fmt: off
yield from inverted_mem[:self.num_to_recall] # fmt: off
def _inverted_text(self, starting_ind):
"""Inverted text chunks"""
inverted_text = self.text_chunks[:starting_ind + 1:][::-1] # fmt: off
yield from inverted_text[:self.num_to_recall] # fmt: off
[docs]
async def parse_from_ind(
self, ind, key, llm_call_callback, *args, **kwargs
):
"""Validate a chunk by consulting current and prior context
Cached verdicts are reused to avoid redundant LLM calls when
neighboring chunks have already been assessed. If the cache
lacks a verdict, the callback is executed and the result stored.
Parameters
----------
ind : int
Index of the chunk to inspect. Must be less than the number
of available chunks.
key : str
JSON key expected in the LLM response. The same key is used
to populate the decision cache.
llm_call_callback : callable
Awaitable invoked with
``await llm_call_callback(key, text_chunk)`` that returns a
boolean indicating whether the chunk satisfies the LLM
validation check.
Returns
-------
bool
``True`` if the selected or recalled chunk satisfies the
check, ``False`` otherwise.
"""
logger.debug("Checking %r for ind %d", key, ind)
mem_text = zip(
self._inverted_mem(ind), self._inverted_text(ind), strict=False
)
for step, (mem, text) in enumerate(mem_text):
logger.debug("Mem at ind %d is %s", step, mem)
check = mem.get(key)
if check is None:
check = mem[key] = await llm_call_callback(
key, text, *args, **kwargs
)
if check:
return check
return False
[docs]
class Heuristic(ABC):
"""Perform a heuristic check for mention of a technology in text"""
_GOOD_ACRONYM_CONTEXTS = [
" {acronym} ",
" {acronym}\n",
" {acronym}.",
"\n{acronym} ",
"\n{acronym}.",
"\n{acronym}\n",
"({acronym} ",
" {acronym})",
]
[docs]
def check(self, text, match_count_threshold=1):
"""Check for mention of a tech in text
This check first strips the text of any tech "look-alike" words
(e.g. "window", "windshield", etc for "wind" technology). Then,
it checks for particular keywords, acronyms, and phrases that
pertain to the tech in the text. If enough keywords are mentions
(as dictated by `match_count_threshold`), this check returns
``True``.
Parameters
----------
text : str
Input text that may or may not mention the technology of
interest.
match_count_threshold : int, optional
Number of keywords that must match for the text to pass this
heuristic check. Count must be strictly greater than this
value. By default, ``1``.
Returns
-------
bool
``True`` if the number of keywords/acronyms/phrases detected
exceeds the `match_count_threshold`.
"""
heuristics_text = self._convert_to_heuristics_text(text)
total_keyword_matches = self._count_single_keyword_matches(
heuristics_text
)
total_keyword_matches += self._count_acronym_matches(heuristics_text)
total_keyword_matches += self._count_phrase_matches(heuristics_text)
return total_keyword_matches > match_count_threshold
def _convert_to_heuristics_text(self, text):
"""Convert text for heuristic content parsing"""
heuristics_text = text.casefold()
for word in self.NOT_TECH_WORDS:
heuristics_text = heuristics_text.replace(word, "")
return heuristics_text
def _count_single_keyword_matches(self, heuristics_text):
"""Count number of good tech keywords that appear in text"""
return sum(
keyword in heuristics_text for keyword in self.GOOD_TECH_KEYWORDS
)
def _count_acronym_matches(self, heuristics_text):
"""Count number of good tech acronyms that appear in text"""
acronym_matches = 0
for context in self._GOOD_ACRONYM_CONTEXTS:
acronym_keywords = {
context.format(acronym=acronym)
for acronym in self.GOOD_TECH_ACRONYMS
}
acronym_matches = sum(
keyword in heuristics_text for keyword in acronym_keywords
)
if acronym_matches > 0:
break
return acronym_matches
def _count_phrase_matches(self, heuristics_text):
"""Count number of good tech phrases that appear in text"""
text_ngrams = {}
total = 0
for phrase in self.GOOD_TECH_PHRASES:
n = len(phrase.split(" "))
if n <= 1:
msg = (
"Make sure your GOOD_TECH_PHRASES contain at least 2 "
f"words! Got phrase: {phrase!r}"
)
warn(msg, COMPASSWarning)
continue
if n not in text_ngrams:
text_ngrams[n] = set(
convert_text_to_sentence_ngrams(heuristics_text, n)
)
test_ngrams = ( # fmt: off
convert_text_to_sentence_ngrams(phrase, n)
+ convert_text_to_sentence_ngrams(f"{phrase}s", n)
)
if any(t in text_ngrams[n] for t in test_ngrams):
total += 1
return total
@property
@abstractmethod
def NOT_TECH_WORDS(self): # noqa: N802
""":class:`~collections.abc.Iterable`: Not tech keywords"""
raise NotImplementedError
@property
@abstractmethod
def GOOD_TECH_KEYWORDS(self): # noqa: N802
""":class:`~collections.abc.Iterable`: Tech keywords"""
raise NotImplementedError
@property
@abstractmethod
def GOOD_TECH_ACRONYMS(self): # noqa: N802
""":class:`~collections.abc.Iterable`: Tech acronyms"""
raise NotImplementedError
@property
@abstractmethod
def GOOD_TECH_PHRASES(self): # noqa: N802
""":class:`~collections.abc.Iterable`: Tech phrases"""
raise NotImplementedError
[docs]
class TextKindValidator(ABC):
"""Base class for a text kind validator
This class is in charge of parsing text in chunks and ultimately
(after some X number of chunks have been parsed) determining if the
text is the right 'kind' of text for a given extraction.
"""
@property
@abstractmethod
def is_correct_kind_of_text(self):
"""bool: ``True`` if text is a good fit for extraction"""
raise NotImplementedError
[docs]
@abstractmethod
async def check_chunk(self, chunk_parser, ind):
"""Check a chunk to see if it contains the right kind of text
You should validate chunks like so::
is_correct_kind_of_text = await chunk_parser.parse_from_ind(
ind,
key="my_unique_validation_key",
llm_call_callback=my_async_llm_call_function,
)
where the `"key"` is unique to this particular validation (it
will be used to cache the validation result in the chunk
parser's memory) and `my_async_llm_call_function` is an async
function that takes in a key and text chunk and returns a
boolean indicating whether or not the text chunk passes the
validation. You can call `chunk_parser.parse_from_ind` as many
times as you want within this method, but be sure to use unique
keys for each validation.
Parameters
----------
chunk_parser : ParseChunksWithMemory
Instance that contains a ``parse_from_ind`` method.
ind : int
Index of the chunk to check.
Returns
-------
bool
Boolean flag indicating whether or not the text in the chunk
resembles legal text.
See Also
--------
ParseChunksWithMemory.parse_from_ind
Method used to parse text from a chunk with memory of prior
chunk validations.
"""
raise NotImplementedError
[docs]
class LegalTextValidator(TextKindValidator, JSONFromTextLLMCaller):
"""Parse chunks to determine if they contain legal text"""
SYSTEM_MESSAGE = (
"You are an AI designed to classify text excerpts based on their "
"source type. The goal is to identify text that is extracted from "
"**legally binding regulations (such as zoning ordinances or "
"enforceable bans)** and filter out text that was extracted from "
"anything other than a legal statute for an existing jurisdiction."
)
"""System message for legal text validation LLM calls"""
def __init__(
self, tech, *args, score_threshold=0.8, doc_is_from_ocr=False, **kwargs
):
"""
Parameters
----------
tech : str
Technology of interest (e.g. "solar", "wind", etc). This is
used to set up some document validation decision trees.
score_threshold : float, optional
Minimum fraction of text chunks that have to pass the legal
check for the whole document to be considered legal text.
By default, ``0.8``.
*args, **kwargs
Parameters to pass to the JSONFromTextLLMCaller initializer.
"""
super().__init__(*args, **kwargs)
self.tech = tech
self.score_threshold = score_threshold
self._legal_text_mem = []
self.doc_is_from_ocr = doc_is_from_ocr
@property
def is_correct_kind_of_text(self):
"""bool: ``True`` if text was found to be from a legal source"""
if not self._legal_text_mem:
return False
score = sum(self._legal_text_mem) / len(self._legal_text_mem)
return score >= self.score_threshold
[docs]
async def check_chunk(self, chunk_parser, ind):
"""Check a chunk at a given ind to see if it contains legal text
Parameters
----------
chunk_parser : ParseChunksWithMemory
Instance that contains a ``parse_from_ind`` method.
ind : int
Index of the chunk to check.
Returns
-------
bool
Boolean flag indicating whether or not the text in the chunk
resembles legal text.
"""
is_legal_text = await chunk_parser.parse_from_ind(
ind,
key="legal_text",
llm_call_callback=self._check_chunk_for_legal_text,
)
self._legal_text_mem.append(is_legal_text)
if is_legal_text:
logger.debug("Text at ind %d is legal text", ind)
else:
logger.debug("Text at ind %d is not legal text", ind)
return is_legal_text
async def _check_chunk_for_legal_text(self, key, text_chunk):
"""Call LLM on a chunk of text to check for legal text"""
chat_llm_caller = ChatLLMCaller(
llm_service=self.llm_service,
system_message=self.SYSTEM_MESSAGE.format(key=key),
usage_tracker=self.usage_tracker,
**self.kwargs,
)
tree = setup_async_decision_tree(
setup_graph_correct_document_type,
usage_sub_label=LLMUsageCategory.DOCUMENT_CONTENT_VALIDATION,
tech=self.tech,
key=key,
text=text_chunk,
chat_llm_caller=chat_llm_caller,
doc_is_from_ocr=self.doc_is_from_ocr,
)
out = await run_async_tree(tree, response_as_json=True)
logger.debug("LLM response: %s", out)
return out.get(key, False)
[docs]
async def parse_by_chunks(
chunk_parser,
heuristic,
text_kind_validator=None,
callbacks=None,
min_chunks_to_process=3,
):
"""Stream text chunks through heuristic and legal validators
This method goes through the chunks one by one, and passes them to
the callback parsers if the `text_kind_validator` check passes. If
`min_chunks_to_process` number of chunks fail the legal text check,
parsing is aborted.
Parameters
----------
chunk_parser : ParseChunksWithMemory
Instance that contains the attributes ``text_chunks`` and
``num_to_recall``. The chunks in the ``text_chunks`` attribute
will be iterated over.
heuristic : Heuristic
Instance of `Heuristic` with a `check` method. This should be a
fast check meant to quickly dispose of chunks of text. Any chunk
that fails this check will NOT be passed to the callback
parsers.
text_kind_validator : TextKindValidator, optional
Instance of `TextKindValidator` subclass that can be used to
validate whether each chunk of text is the kind needed for
extraction (e.g. is it legal text?). If not provided, the text
'kind' check will be skipped. By default, ``None``.
callbacks : list, optional
List of async callbacks that take a `chunk_parser` and `index`
as inputs and return a boolean determining whether the text
chunk was parsed successfully or not. By default, ``None``,
which does not use any callbacks.
min_chunks_to_process : int, optional
Minimum number of chunks to process before aborting due to text
not being legal. By default, ``3``.
Notes
-----
This coroutine only orchestrates validation. Callbacks are
responsible for persisting any extracted results. Callback futures
are awaited concurrently and share the same task name as the caller
to simplify tracing within structured logging.
"""
passed_heuristic_mem = []
callbacks = callbacks or []
outer_task_name = asyncio.current_task().get_name()
for ind, text in enumerate(chunk_parser.text_chunks):
passed_heuristic_mem.append(heuristic.check(text))
if ind < min_chunks_to_process:
if text_kind_validator is not None:
is_correct_text_kind = await text_kind_validator.check_chunk(
chunk_parser, ind
)
if not is_correct_text_kind:
continue # don't bother checking this chunk
elif (
text_kind_validator is not None
and not text_kind_validator.is_correct_kind_of_text
):
return # don't bother checking this document
# hasn't passed heuristic, so don't pass it to callbacks
elif not any(passed_heuristic_mem[-chunk_parser.num_to_recall :]):
continue
logger.debug("Processing text at ind %d", ind)
logger.debug_to_file("Text:\n%s", text)
if not callbacks:
continue
cb_futures = [
asyncio.create_task(cb(chunk_parser, ind), name=outer_task_name)
for cb in callbacks
]
cb_results = await asyncio.gather(*cb_futures)
# mask this chunk if we got a good result - this avoids forcing
# the following chunk to be checked (it will only be checked if
# it itself passes the heuristic)
passed_heuristic_mem[-1] = not any(cb_results)