perf(dlp): memoize encoded variants and linearize partial-window scan

Two per-request hot-path costs in the egress DLP scanner:

- `_encoded_variants` derived the full variant set (gzip + nine
  encodings) for every provisioned secret on every redaction and
  known-secret scan — once per host, path, header, and body. Cache it
  per distinct secret; callers still get a fresh list so they can't
  corrupt the shared cached tuple.
- `_find_partial_window` searched the text once per secret n-gram,
  giving O(len(secret) * len(text)). Build the secret's n-gram set once
  and sweep the text a single time: O(len(text)), no coverage loss.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_01NkwFXLFff9PYPy4wgVBJp9
This commit is contained in:
2026-06-26 22:53:27 -04:00
parent ebbcae663c
commit 0bb47bd754
2 changed files with 49 additions and 10 deletions
+38 -10
View File
@@ -126,8 +126,30 @@ def redact_tokens(
# Known secrets detector
# ---------------------------------------------------------------------------
# Encoded-variant cache. Provisioned secrets are stable for the life of the
# proxy, but `_encoded_variants` is on the per-request hot path — it runs for
# every secret on every redaction and known-secret scan (host, path, each
# header, body). Deriving the variant set is relatively expensive (gzip +
# nine encodings), so memoize it per distinct secret. The proxy process
# already holds these values in `os.environ`, so caching them here adds no
# new exposure.
_VARIANT_CACHE: dict[str, tuple[str, ...]] = {}
def _encoded_variants(secret: str) -> list[str]:
"""Return the secret plus common encoded variants for exfil detection."""
"""Return the secret plus common encoded variants for exfil detection.
The variant set is computed once per distinct secret and cached; callers
get a fresh list so they can't mutate the shared cached tuple."""
cached = _VARIANT_CACHE.get(secret)
if cached is None:
cached = _compute_encoded_variants(secret)
_VARIANT_CACHE[secret] = cached
return list(cached)
def _compute_encoded_variants(secret: str) -> tuple[str, ...]:
"""Derive the secret plus its encoded variants (uncached)."""
seen: set[str] = {secret}
variants: list[str] = [secret]
@@ -161,7 +183,7 @@ def _encoded_variants(secret: str) -> list[str]:
# gzip + base64 (deterministic: mtime=0); recognisable by H4sI prefix
_add(base64.b64encode(gzip.compress(secret_bytes, mtime=0)).decode("ascii"))
return variants
return tuple(variants)
# ---------------------------------------------------------------------------
@@ -187,18 +209,24 @@ def _alnum_projection(text: str) -> str:
def _find_partial_window(secret_alnum: str, text_alnum: str, min_len: int) -> int | None:
"""Return the position in text_alnum where any min_len-char window of
secret_alnum first appears, or None.
"""Return the earliest position in text_alnum holding a min_len-char window
that also appears in secret_alnum, or None.
Slides a window of width min_len across secret_alnum and searches for
each window in text_alnum. The first hit position is returned.
The secret's set of min_len-grams is small (bounded by the secret length),
so building it once and sweeping the text a single time is O(len(text))
rather than the O(len(secret) * len(text)) of repeated substring searches —
which matters because this runs per provisioned secret on every request
body. Coverage is unchanged: a hit still means at least min_len consecutive
alphanumeric characters of the secret leaked into the text.
"""
if len(secret_alnum) < min_len or len(text_alnum) < min_len:
return None
for i in range(len(secret_alnum) - min_len + 1):
window = secret_alnum[i:i + min_len]
pos = text_alnum.find(window)
if pos >= 0:
secret_grams = {
secret_alnum[i:i + min_len]
for i in range(len(secret_alnum) - min_len + 1)
}
for pos in range(len(text_alnum) - min_len + 1):
if text_alnum[pos:pos + min_len] in secret_grams:
return pos
return None
+11
View File
@@ -281,6 +281,17 @@ class TestEncodedVariants(unittest.TestCase):
v = self._variants()
self.assertEqual(len(v), len(set(v)))
def test_repeated_calls_equal(self):
# Memoization must not change observable output.
self.assertEqual(self._variants(), self._variants())
def test_returns_fresh_list_each_call(self):
# Callers mutate/iterate the result; the cached set must not be
# exposed by reference, or one caller could corrupt another's view.
first = self._variants()
first.append("MUTATED")
self.assertNotIn("MUTATED", self._variants())
class TestUnicodeNormalization(unittest.TestCase):
def test_fullwidth_chars_normalized(self):