diff --git a/bot_bottle/backend/docker/launch.py b/bot_bottle/backend/docker/launch.py index e811cbe..a1a6f5e 100644 --- a/bot_bottle/backend/docker/launch.py +++ b/bot_bottle/backend/docker/launch.py @@ -42,11 +42,7 @@ from contextlib import ExitStack, contextmanager from pathlib import Path from typing import Callable, Generator -from ...codex_auth import codex_host_access_token -from ...egress import ( - CODEX_HOST_CREDENTIAL_TOKEN_REF, - egress_resolve_token_values, -) +from ...egress import egress_resolve_token_values_with_provider from ...log import info from . import network as network_mod from . import util as docker_mod @@ -180,18 +176,12 @@ def launch( # Step 7: compose up. Token values + the OAuth placeholder # flow through subprocess env; the compose file holds only # bare names for the secret-carrying entries. - token_values: dict[str, str] = {} - if plan.egress_plan.routes: - token_values = egress_resolve_token_values( - plan.egress_plan.token_env_map, dict(os.environ), - ) - if plan.spec.manifest.bottle_for( - plan.spec.agent_name, - ).agent_provider.forward_host_credentials: - access_token = codex_host_access_token(dict(os.environ)) - for token_env, token_ref in plan.egress_plan.token_env_map.items(): - if token_ref == CODEX_HOST_CREDENTIAL_TOKEN_REF: - token_values[token_env] = access_token + bottle = plan.spec.manifest.bottle_for(plan.spec.agent_name) + token_values = egress_resolve_token_values_with_provider( + plan.egress_plan.token_env_map, + bottle.agent_provider.forward_host_credentials, + dict(os.environ), + ) compose_env: dict[str, str] = { **os.environ, **plan.forwarded_env, diff --git a/bot_bottle/backend/smolmachines/launch.py b/bot_bottle/backend/smolmachines/launch.py index da73dae..629eb2f 100644 --- a/bot_bottle/backend/smolmachines/launch.py +++ b/bot_bottle/backend/smolmachines/launch.py @@ -26,11 +26,9 @@ from contextlib import ExitStack, contextmanager from pathlib import Path from typing import Callable, Generator -from ...codex_auth import codex_host_access_token from ...egress import ( - CODEX_HOST_CREDENTIAL_TOKEN_REF, EGRESS_ROUTES_IN_CONTAINER, - egress_resolve_token_values, + egress_resolve_token_values_with_provider, ) from ...pipelock import ( PIPELOCK_CA_CERT_IN_CONTAINER, @@ -146,7 +144,7 @@ def launch( # spec's ports_to_publish list expands depending on which # daemons the agent needs to reach from the smolvm guest. bundle_spec = _bundle_launch_spec(plan, network, loopback_ip) - token_env = _resolve_token_env(plan, os.environ) + token_env = _resolve_token_env(plan, dict(os.environ)) _bundle.ensure_bundle_image(bundle_spec.image) _bundle.start_bundle(bundle_spec, env={**os.environ, **token_env}) stack.callback(_bundle.stop_bundle, plan.slug) @@ -420,24 +418,17 @@ def _bundle_launch_spec( def _resolve_token_env( - plan: SmolmachinesBottlePlan, host_env: object + plan: SmolmachinesBottlePlan, host_env: dict[str, str], ) -> dict[str, str]: """Resolve the egress token env-var values from the host's environ so they reach the bundle's process env via docker's `-e NAME` inheritance. Empty when no routes declare auth.""" - ep = plan.egress_plan - if not ep.routes: - return {} - env = dict(host_env) - token_values = egress_resolve_token_values(ep.token_env_map, env) - if plan.spec.manifest.bottle_for( - plan.spec.agent_name, - ).agent_provider.forward_host_credentials: - access_token = codex_host_access_token(env) - for token_env, token_ref in ep.token_env_map.items(): - if token_ref == CODEX_HOST_CREDENTIAL_TOKEN_REF: - token_values[token_env] = access_token - return token_values + bottle = plan.spec.manifest.bottle_for(plan.spec.agent_name) + return egress_resolve_token_values_with_provider( + plan.egress_plan.token_env_map, + bottle.agent_provider.forward_host_credentials, + host_env, + ) def _ensure_smolmachine(image_ref: str, *, dockerfile: str = "") -> Path: diff --git a/bot_bottle/egress.py b/bot_bottle/egress.py index 000d5e1..19988d4 100644 --- a/bot_bottle/egress.py +++ b/bot_bottle/egress.py @@ -29,6 +29,7 @@ from dataclasses import dataclass from pathlib import Path from typing import TYPE_CHECKING +from .codex_auth import codex_host_access_token from .log import die if TYPE_CHECKING: @@ -360,6 +361,31 @@ def egress_resolve_token_values( return out +def egress_resolve_token_values_with_provider( + token_env_map: dict[str, str], + forward_host_credentials: bool, + host_env: dict[str, str], +) -> dict[str, str]: + """Resolve all egress token env-var values, including the optional + Codex host credential slot. + + Combines `egress_resolve_token_values` (manifest-declared token refs) + with the `forward_host_credentials` path (Codex ChatGPT bearer). + Returns an empty dict when `token_env_map` is empty. + + Pure function: `host_env` is passed in so tests can use a sealed + mapping without touching `os.environ`.""" + if not token_env_map: + return {} + token_values = egress_resolve_token_values(token_env_map, host_env) + if forward_host_credentials: + access_token = codex_host_access_token(host_env) + for token_env, token_ref in token_env_map.items(): + if token_ref == CODEX_HOST_CREDENTIAL_TOKEN_REF: + token_values[token_env] = access_token + return token_values + + class Egress(ABC): """The per-bottle egress proxy. Encapsulates the host-side prepare (route lift + routes.yaml render + token-env-map derivation); the @@ -403,6 +429,7 @@ __all__ = [ "egress_manifest_routes", "egress_render_routes", "egress_resolve_token_values", + "egress_resolve_token_values_with_provider", "egress_routes_for_bottle", "egress_token_env_map", ] diff --git a/tests/unit/test_egress.py b/tests/unit/test_egress.py index 7d7107c..55b89c4 100644 --- a/tests/unit/test_egress.py +++ b/tests/unit/test_egress.py @@ -2,6 +2,7 @@ resolution (PRD 0017).""" import unittest +import unittest.mock from bot_bottle.egress import ( CODEX_HOST_CREDENTIAL_TOKEN_REF, @@ -9,6 +10,7 @@ from bot_bottle.egress import ( egress_manifest_routes, egress_render_routes, egress_resolve_token_values, + egress_resolve_token_values_with_provider, egress_routes_for_bottle, egress_token_env_map, ) @@ -349,5 +351,64 @@ class TestResolveTokenValues(unittest.TestCase): self.assertEqual({}, out) +class TestResolveTokenValuesWithProvider(unittest.TestCase): + def test_empty_map_returns_empty(self): + out = egress_resolve_token_values_with_provider({}, False, {}) + self.assertEqual({}, out) + + def test_empty_map_with_forward_credentials_returns_empty(self): + # forward_host_credentials=True but no slots → no codex call needed. + out = egress_resolve_token_values_with_provider({}, True, {}) + self.assertEqual({}, out) + + def test_manifest_tokens_resolved_without_forward_credentials(self): + out = egress_resolve_token_values_with_provider( + {"EGRESS_TOKEN_0": "GH_PAT"}, + False, + {"GH_PAT": "ghp_secret"}, + ) + self.assertEqual({"EGRESS_TOKEN_0": "ghp_secret"}, out) + + def test_codex_token_slotted_in_when_forward_credentials_and_matching_ref(self): + with unittest.mock.patch( + "bot_bottle.egress.codex_host_access_token", + return_value="codex-access-token", + ): + out = egress_resolve_token_values_with_provider( + {"EGRESS_TOKEN_0": CODEX_HOST_CREDENTIAL_TOKEN_REF}, + True, + {}, + ) + self.assertEqual({"EGRESS_TOKEN_0": "codex-access-token"}, out) + + def test_codex_token_not_slotted_when_no_matching_ref(self): + # forward_host_credentials=True but no CODEX_HOST_CREDENTIAL_TOKEN_REF + # slot in the map → manifest tokens only; Codex token is fetched but + # nothing to slot it into. + with unittest.mock.patch( + "bot_bottle.egress.codex_host_access_token", + return_value="codex-access-token", + ): + out = egress_resolve_token_values_with_provider( + {"EGRESS_TOKEN_0": "GH_PAT"}, + True, + {"GH_PAT": "ghp_secret"}, + ) + self.assertEqual({"EGRESS_TOKEN_0": "ghp_secret"}, out) + + def test_codex_not_called_when_forward_credentials_false(self): + called = [] + with unittest.mock.patch( + "bot_bottle.egress.codex_host_access_token", + side_effect=lambda *_: called.append(1) or "tok", + ): + egress_resolve_token_values_with_provider( + {"EGRESS_TOKEN_0": "GH_PAT"}, + False, + {"GH_PAT": "ghp_secret"}, + ) + self.assertEqual([], called) + + if __name__ == "__main__": unittest.main()