diff --git a/bot_bottle/codex_auth.py b/bot_bottle/codex_auth.py index 9a0e6f9..9f6da0a 100644 --- a/bot_bottle/codex_auth.py +++ b/bot_bottle/codex_auth.py @@ -13,6 +13,7 @@ import os from copy import deepcopy from datetime import datetime, timezone from pathlib import Path +from typing import cast from .log import die from .util import expand_tilde @@ -50,7 +51,8 @@ def codex_host_access_token( tokens = raw.get("tokens") if not isinstance(tokens, dict): die(f"codex host credentials: {path} is missing tokens") - access = tokens.get("access_token") + tokens_typed = cast(dict[str, object], tokens) + access = tokens_typed.get("access_token") if not isinstance(access, str) or not access: die( f"codex host credentials: {path} is missing tokens.access_token. " @@ -105,14 +107,14 @@ def write_codex_dummy_auth_file( path.chmod(0o600) -def _read_auth_object(path: Path) -> dict: +def _read_auth_object(path: Path) -> dict[str, object]: try: raw = json.loads(path.read_text()) except (OSError, json.JSONDecodeError) as e: die(f"codex host credentials: could not read valid JSON at {path}: {e}") if not isinstance(raw, dict): die(f"codex host credentials: {path} must contain a JSON object") - return raw + return cast(dict[str, object], raw) def _dummy_exp(now: datetime | None, exp_ts: int | None) -> int: @@ -151,11 +153,11 @@ def _dummy_jwt_from_host( return _dummy_jwt(now, exp_ts=exp_ts) if not isinstance(payload, dict): return _dummy_jwt(now, exp_ts=exp_ts) - return _encode_dummy_jwt(_redact_jwt_payload(payload, now=now, exp_ts=exp_ts)) + return _encode_dummy_jwt(_redact_jwt_payload(cast(dict[str, object], payload), now=now, exp_ts=exp_ts)) -def _encode_dummy_jwt(payload: dict) -> str: - def enc(obj: dict) -> str: +def _encode_dummy_jwt(payload: dict[str, object]) -> str: + def enc(obj: dict[str, object]) -> str: raw = json.dumps(obj, separators=(",", ":")).encode() return base64.urlsafe_b64encode(raw).decode().rstrip("=") @@ -163,23 +165,24 @@ def _encode_dummy_jwt(payload: dict) -> str: def _redact_jwt_payload( - payload: dict, + payload: dict[str, object], *, now: datetime | None = None, exp_ts: int | None = None, -) -> dict: +) -> dict[str, object]: out = _redact_claims(payload) if not isinstance(out, dict): out = {} - out["exp"] = _dummy_exp(now, exp_ts) - out.setdefault("sub", "bot-bottle-placeholder") - return out + out_typed: dict[str, object] = cast(dict[str, object], out) + out_typed["exp"] = _dummy_exp(now, exp_ts) + out_typed.setdefault("sub", "bot-bottle-placeholder") + return out_typed def _redact_claims(value: object) -> object: if isinstance(value, dict): out: dict[str, object] = {} - for key, inner in value.items(): + for key, inner in cast(dict[str, object], value).items(): lower = key.lower() if key == "https://api.openai.com/profile": out[key] = _redact_profile_claim(inner) @@ -207,16 +210,16 @@ def _redact_claims(value: object) -> object: return "bot-bottle-placeholder" -def _redact_profile_claim(value: object) -> dict: - profile = value if isinstance(value, dict) else {} +def _redact_profile_claim(value: object) -> dict[str, object]: + profile = cast(dict[str, object], value) if isinstance(value, dict) else {} return { "email": "bot-bottle@example.invalid", "email_verified": bool(profile.get("email_verified", True)), } -def _redact_auth_claim(value: object) -> dict: - auth = value if isinstance(value, dict) else {} +def _redact_auth_claim(value: object) -> dict[str, object]: + auth = cast(dict[str, object], value) if isinstance(value, dict) else {} out: dict[str, object] = {} for key, inner in auth.items(): lower = key.lower() @@ -247,7 +250,7 @@ def _redact_auth_claim(value: object) -> dict: def _redact_codex_auth( value: object, *, now: datetime | None = None, exp_ts: int | None = None, ) -> object: - auth = value if isinstance(value, dict) else {} + auth = cast(dict[str, object], value) if isinstance(value, dict) else {} out: dict[str, object] = {} for key, inner in auth.items(): lower = key.lower() @@ -269,7 +272,7 @@ def _redact_codex_auth( def _redact_token_block( value: object, *, now: datetime | None = None, exp_ts: int | None = None, ) -> dict[str, object]: - tokens = value if isinstance(value, dict) else {} + tokens = cast(dict[str, object], value) if isinstance(value, dict) else {} out: dict[str, object] = {} for key, inner in tokens.items(): lower = key.lower() @@ -306,7 +309,7 @@ def _jwt_exp(token: str) -> datetime | None: return None if not isinstance(payload, dict): return None - exp = payload.get("exp") + exp = cast(dict[str, object], payload).get("exp") if not isinstance(exp, (int, float)): return None return datetime.fromtimestamp(exp, timezone.utc)