fix(pi): select configured startup models
test / unit (pull_request) Successful in 30s
test / integration (pull_request) Successful in 17s
lint / lint (push) Successful in 1m37s
test / unit (push) Successful in 33s
test / integration (push) Successful in 17s
Update Quality Badges / update-badges (push) Successful in 1m6s

This commit was merged in pull request #222.
This commit is contained in:
2026-06-09 06:57:33 -04:00
parent 199edb228c
commit 86374ab293
9 changed files with 47 additions and 9 deletions
+1
View File
@@ -114,6 +114,7 @@ class AgentProvisionPlan:
prompt_file: Path
guest_env: dict[str, str]
has_prompt: bool = False
startup_args: tuple[str, ...] = ()
env_vars: dict[str, str] = field(default_factory=dict)
dirs: tuple[AgentProvisionDir, ...] = ()
files: tuple[AgentProvisionFile, ...] = ()
+2 -3
View File
@@ -23,6 +23,7 @@ class DockerBottle(Bottle):
*,
agent_command: str = "claude",
agent_prompt_mode: PromptMode = "append_file",
agent_provider_template: str = "claude",
terminal_title: str = "",
terminal_color: str = "",
):
@@ -33,9 +34,7 @@ class DockerBottle(Bottle):
self.agent_command = agent_command
self.terminal_title = terminal_title
self.terminal_color = terminal_color
self.agent_provider_template = (
"codex" if agent_command == "codex" else "claude"
)
self.agent_provider_template = agent_provider_template
self._closed = False
def agent_argv(
+1
View File
@@ -175,6 +175,7 @@ def launch(
None,
agent_command=plan.agent_command,
agent_prompt_mode=plan.agent_prompt_mode,
agent_provider_template=plan.agent_provider_template,
terminal_title=plan.spec.label or plan.spec.agent_name,
terminal_color=plan.spec.color,
)
+2 -3
View File
@@ -69,6 +69,7 @@ class SmolmachinesBottle(Bottle):
guest_env: Mapping[str, str] | None = None,
agent_command: str = "claude",
agent_prompt_mode: PromptMode = "append_file",
agent_provider_template: str = "claude",
terminal_title: str = "",
terminal_color: str = "",
) -> None:
@@ -86,9 +87,7 @@ class SmolmachinesBottle(Bottle):
self.agent_command = agent_command
self.terminal_title = terminal_title
self.terminal_color = terminal_color
self.agent_provider_template = (
"codex" if agent_command == "codex" else "claude"
)
self.agent_provider_template = agent_provider_template
def agent_argv(
self, argv: list[str], *, tty: bool = True,
@@ -103,6 +103,7 @@ def launch(
guest_env=plan.guest_env,
agent_command=plan.agent_command,
agent_prompt_mode=plan.agent_prompt_mode,
agent_provider_template=plan.agent_provider_template,
terminal_title=plan.spec.label or plan.spec.agent_name,
terminal_color=plan.spec.color,
)
+3
View File
@@ -133,6 +133,7 @@ def prepare_with_preflight(
def attach_agent(
bottle: Bottle, *, remote_control: bool = False, resume: bool = False,
agent_provider_template: str = "claude",
startup_args: tuple[str, ...] = (),
) -> int:
"""Run the selected provider CLI inside `bottle` as an
interactive session. Blocks until the session ends; returns the
@@ -151,6 +152,7 @@ def attach_agent(
agent_args = list(runtime.bypass_args)
if remote_control:
agent_args.extend(runtime.remote_control_args)
agent_args.extend(startup_args)
if resume:
agent_args.extend(runtime.resume_args)
return bottle.exec_agent(agent_args, tty=True)
@@ -235,6 +237,7 @@ def _launch_bottle(
bottle,
remote_control=remote_control,
agent_provider_template=agent_provider_template,
startup_args=plan.agent_provision.startup_args,
)
info(
f"session ended (exit {exit_code}); "
+9 -3
View File
@@ -58,7 +58,7 @@ def _settings_value(
def _pi_models_json(
settings: dict[str, object],
) -> tuple[dict[str, object], str, str]:
) -> tuple[dict[str, object], str, str, list[str], str]:
provider_name = str(
_settings_value(settings, "provider", _DEFAULT_PROVIDER_NAME)
)
@@ -94,7 +94,7 @@ def _pi_models_json(
provider_name: provider,
}
}
return payload, base_url, api_key_env
return payload, base_url, api_key_env, models, provider_name
def _route_host(base_url: str) -> str:
@@ -145,7 +145,9 @@ class PiAgentProvider(AgentProvider):
guest_home = self.guest_home
settings = dict(provider_settings or {})
models_payload, base_url, api_key_env = _pi_models_json(settings)
models_payload, base_url, api_key_env, models, provider_name = (
_pi_models_json(settings)
)
models_file = state_dir / "pi-models.json"
models_file.write_text(json.dumps(models_payload, indent=2) + "\n")
models_file.chmod(0o600)
@@ -163,6 +165,10 @@ class PiAgentProvider(AgentProvider):
prompt_file=prompt_file,
guest_env=resolved_guest_env,
has_prompt=has_prompt,
startup_args=(
"--models",
",".join(f"{provider_name}/{model}" for model in models),
),
dirs=(AgentProvisionDir(f"{guest_home}/.pi/agent"),),
files=(AgentProvisionFile(models_file, _models_path(guest_home)),),
egress_routes=(EgressRoute(
+5
View File
@@ -284,6 +284,7 @@ class TestAgentProviderRuntime(unittest.TestCase):
("/home/node/.pi/agent/models.json",),
tuple(f.guest_path for f in plan.files),
)
self.assertEqual(("--models", "ollama/qwen2.5-coder:7b"), plan.startup_args)
provider = models["providers"]["ollama"]
self.assertEqual("http://ollama:11434/v1", provider["baseUrl"])
self.assertEqual("openai-completions", provider["api"])
@@ -348,6 +349,10 @@ class TestAgentProviderRuntime(unittest.TestCase):
[{"id": "google/gemma-4-26b-a4b-it:free"}],
provider["models"],
)
self.assertEqual(
("--models", "openrouter/google/gemma-4-26b-a4b-it:free"),
plan.startup_args,
)
self.assertEqual("openrouter.ai", plan.egress_routes[0].host)
self.assertEqual("Bearer", plan.egress_routes[0].auth_scheme)
self.assertEqual("OPENROUTER_API_KEY", plan.egress_routes[0].token_ref)
+23
View File
@@ -80,5 +80,28 @@ class TestSettleState(_FakeHomeMixin, unittest.TestCase):
start_mod.settle_state("") # should not raise
class TestAttachAgent(unittest.TestCase):
def test_passes_provider_startup_args(self):
class Bottle:
argv: list[str] = []
def exec_agent(self, argv: list[str], *, tty: bool = True) -> int:
self.argv = list(argv)
return 0
bottle = Bottle()
exit_code = start_mod.attach_agent(
bottle, # type: ignore[arg-type]
agent_provider_template="pi",
startup_args=("--models", "openrouter/google/gemma"),
)
self.assertEqual(0, exit_code)
self.assertEqual(
["--models", "openrouter/google/gemma"],
bottle.argv,
)
if __name__ == "__main__":
unittest.main()