commit 45e3775e5c9e92756b649e045abf9af540e081e7
parent 1c54d4f169b6a9c7d425b9c3b72fce2b1f3bdb54
Author: triesap <tyson@radroots.org>
Date: Sun, 14 Jun 2026 16:20:45 -0700
provider: execute max local query rewrite
Add a repo-local MAX provider stub and stdio contract coverage for assisted query_rewrite execution. Record provider metadata, prompt/schema versions, latency, and assisted provenance for successful max_local responses.
Diffstat:
4 files changed, 347 insertions(+), 7 deletions(-)
diff --git a/src/hyf_provider/max_local.mojo b/src/hyf_provider/max_local.mojo
@@ -1,3 +1,5 @@
+from std.time import perf_counter_ns
+
from hyf_core.capabilities.query_analysis import QueryAnalysis
from hyf_core.request_context import RequestContext
from hyf_provider.client import (
@@ -32,10 +34,12 @@ def execute_query_rewrite_via_max_local_provider(
config: MaxLocalProviderConfig, text: String, context: RequestContext
) raises -> MaxLocalQueryRewriteResult:
with make_max_local_http_client(config) as client:
+ var start_ns = perf_counter_ns()
var response = client.post(
max_local_chat_completions_url(config),
build_query_rewrite_request_body(config, text, context),
)
+ var latency_ms = Int((perf_counter_ns() - start_ns) // 1_000_000)
if not response.ok():
raise Error(
"max_local provider returned HTTP "
@@ -49,7 +53,7 @@ def execute_query_rewrite_via_max_local_provider(
provider="max_local",
route=String(config.route),
model=String(config.model),
- latency_ms=0,
+ latency_ms=latency_ms,
schema_version=query_rewrite_schema_version(),
prompt_version=query_rewrite_prompt_version(),
)
diff --git a/tests/max_local_http_stub.py b/tests/max_local_http_stub.py
@@ -0,0 +1,112 @@
+import argparse
+import json
+import time
+from http.server import BaseHTTPRequestHandler, HTTPServer
+
+
+def query_rewrite_analysis():
+ return {
+ "original_text": "local apples pickup weekend",
+ "normalized_text": "local apples pickup weekend",
+ "rewritten_text": "apples pickup weekend",
+ "query_terms": ["apples", "pickup", "weekend"],
+ "normalization_signals": ["lowercase", "local_intent_detected"],
+ "ranking_hints": ["prefer_local_results", "prefer_pickup"],
+ "extracted_filters": {
+ "local_intent": True,
+ "fulfillment": "pickup",
+ "time_window": "weekend",
+ },
+ }
+
+
+def chat_completion(body):
+ return {
+ "choices": [
+ {
+ "message": {
+ "content": json.dumps(body, separators=(",", ":"))
+ }
+ }
+ ]
+ }
+
+
+class StubServer(HTTPServer):
+ def __init__(self, server_address, handler_class, mode, requests):
+ super().__init__(server_address, handler_class)
+ self.mode = mode
+ self.requests_remaining = requests
+
+
+class Handler(BaseHTTPRequestHandler):
+ def log_message(self, format, *args):
+ return
+
+ def _send(self, status, payload, content_type="application/json"):
+ if isinstance(payload, str):
+ body = payload.encode("utf-8")
+ else:
+ body = json.dumps(payload, separators=(",", ":")).encode("utf-8")
+ self.send_response(status)
+ self.send_header("content-type", content_type)
+ self.send_header("content-length", str(len(body)))
+ self.end_headers()
+ self.wfile.write(body)
+
+ def do_GET(self):
+ if self.path == "/health":
+ if self.server.mode == "health_non_2xx":
+ self._send(503, {"status": "unavailable"})
+ else:
+ self._send(200, {"status": "ok"})
+ return
+ self._send(404, {"error": "not_found"})
+
+ def do_POST(self):
+ if self.path != "/v1/chat/completions":
+ self._send(404, {"error": "not_found"})
+ return
+ mode = self.server.mode
+ if mode == "query_rewrite_ok":
+ self._send(200, chat_completion(query_rewrite_analysis()))
+ elif mode == "query_rewrite_non_2xx":
+ self._send(503, {"error": {"message": "provider unavailable"}})
+ elif mode == "query_rewrite_invalid_json":
+ self._send(200, '{"choices":[{"message":{"content":"not json"}}]}')
+ elif mode == "query_rewrite_schema_invalid":
+ body = query_rewrite_analysis()
+ del body["rewritten_text"]
+ self._send(200, chat_completion(body))
+ elif mode == "query_rewrite_empty_choices":
+ self._send(200, {"choices": []})
+ elif mode == "query_rewrite_missing_content":
+ self._send(200, {"choices": [{"message": {}}]})
+ elif mode == "query_rewrite_error_payload":
+ self._send(200, {"error": {"message": "provider refusal"}})
+ elif mode == "query_rewrite_timeout":
+ time.sleep(2)
+ self._send(200, chat_completion(query_rewrite_analysis()))
+ else:
+ self._send(500, {"error": "unsupported_mode"})
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--port", type=int, required=True)
+ parser.add_argument("--mode", required=True)
+ parser.add_argument("--requests", type=int, default=2)
+ args = parser.parse_args()
+
+ server = StubServer(
+ ("127.0.0.1", args.port), Handler, args.mode, args.requests
+ )
+ print("ready", flush=True)
+ while server.requests_remaining > 0:
+ server.handle_request()
+ server.requests_remaining -= 1
+ server.server_close()
+
+
+if __name__ == "__main__":  # script entry point; importing the module starts no server
+ main()
diff --git a/tests/max_local_process_helper.mojo b/tests/max_local_process_helper.mojo
@@ -0,0 +1,123 @@
+from std.collections import List, Optional
+from std.ffi import CStringSlice, c_int, external_call
+from std.os import Pipe, Process
+from std.sys._libc import close
+
+from flare.net import SocketAddr
+from flare.tcp import TcpListener
+
+
+def _dup2(oldfd: c_int, newfd: c_int) -> c_int:
+ return external_call["dup2", c_int](oldfd, newfd)  # calls the C `dup2` symbol and returns its result unchanged
+
+
+@always_inline
+def _fork() -> c_int:
+ return external_call["fork", c_int]()  # calls the C `fork` symbol; the caller tests the result for < 0 and == 0
+
+
+@always_inline
+def _exit_child(code: c_int):
+ _ = external_call["_exit", c_int](code)  # calls the C `_exit` symbol with the given status
+
+
+def _read_pipe_line(mut pipe: Pipe) raises -> String:  # returns the text read before the first "\n"
+ var buffer = InlineArray[Byte, 1](fill=0)  # one-byte buffer: each read yields at most one byte
+ var output = String("")
+ while True:
+ var read = pipe.read_bytes(Span(buffer))
+ if read == 0:  # nothing was read (presumably end of stream): stop with what we have
+ break
+ var chunk = String(  # NOTE(review): decodes byte by byte; presumably fine only for single-byte (ASCII) output — confirm
+ from_utf8=Span(ptr=buffer.unsafe_ptr(), length=Int(read))
+ )
+ if chunk == "\n":  # the newline terminates the line and is not appended
+ break
+ output += chunk
+ return output^
+
+
+struct SpawnedMaxLocalStub(Movable):  # handle for a spawned stub process, identified by its pid
+ var pid: Int
+
+ def __init__(out self, pid: Int):
+ self.pid = pid
+
+ def wait(mut self) raises:  # waits for the process; raises unless it exited with code 0
+ var process = Process(self.pid)
+ var status = process.wait()
+ if not status.exit_code or status.exit_code.value() != 0:  # no exit code, or a non-zero one
+ raise Error("max_local stub exited unexpectedly")
+
+
+def reserve_loopback_port() raises -> Int:
+ var listener = TcpListener.bind(SocketAddr.localhost(0))  # port 0: presumably lets the OS choose a free port — confirm
+ var port = Int(listener.local_addr().port)
+ listener.close()  # NOTE(review): the port is no longer held from here until the stub binds it
+ return port
+
+
+def spawn_max_local_stub(  # forks, execs the Python stub, and returns once it has printed "ready"
+ port: Int, mode: String, requests: Int
+) raises -> SpawnedMaxLocalStub:
+ var stdout_pipe = Pipe()
+ var command = String("python")
+ var entrypoint = String("tests/max_local_http_stub.py")
+ var arg_port_flag = String("--port")
+ var arg_port = String(port)
+ var arg_mode_flag = String("--mode")
+ var arg_mode = String(mode)
+ var arg_requests_flag = String("--requests")
+ var arg_requests = String(requests)
+ var argv = List[Optional[CStringSlice[ImmutAnyOrigin]]](length=9, fill={})  # 8 arguments are set below; slot 8 stays empty (presumably execvp's NULL terminator)
+ argv[0] = rebind[CStringSlice[ImmutAnyOrigin]](
+ command.as_c_string_slice()
+ )
+ argv[1] = rebind[CStringSlice[ImmutAnyOrigin]](
+ entrypoint.as_c_string_slice()
+ )
+ argv[2] = rebind[CStringSlice[ImmutAnyOrigin]](
+ arg_port_flag.as_c_string_slice()
+ )
+ argv[3] = rebind[CStringSlice[ImmutAnyOrigin]](
+ arg_port.as_c_string_slice()
+ )
+ argv[4] = rebind[CStringSlice[ImmutAnyOrigin]](
+ arg_mode_flag.as_c_string_slice()
+ )
+ argv[5] = rebind[CStringSlice[ImmutAnyOrigin]](
+ arg_mode.as_c_string_slice()
+ )
+ argv[6] = rebind[CStringSlice[ImmutAnyOrigin]](
+ arg_requests_flag.as_c_string_slice()
+ )
+ argv[7] = rebind[CStringSlice[ImmutAnyOrigin]](
+ arg_requests.as_c_string_slice()
+ )
+ var stdout_read_fd = c_int(stdout_pipe.fd_in.value().value)
+ var stdout_write_fd = c_int(stdout_pipe.fd_out.value().value)
+ var command_ptr = command.as_c_string_slice().unsafe_ptr()
+ var argv_ptr = argv.unsafe_ptr()
+
+ var pid = _fork()
+ if pid < 0:
+ raise Error("failed to spawn max_local stub")
+
+ if pid == 0:  # child process
+ if _dup2(stdout_write_fd, 1) < 0:  # make the pipe's write end the child's stdout (fd 1)
+ _exit_child(c_int(126))
+ _ = close(stdout_read_fd)
+ _ = close(stdout_write_fd)
+ _ = external_call["execvp", c_int](command_ptr, argv_ptr)  # returns only if the exec failed
+ _exit_child(c_int(127))
+
+ stdout_pipe.set_input_only()
+ var ready_line = _read_pipe_line(stdout_pipe)  # first line the stub wrote to its stdout
+ if ready_line != "ready":
+ stdout_pipe.set_output_only()
+ var process = Process(Int(pid))
+ _ = process.wait()  # wait for the child to exit before raising
+ raise Error("max_local stub failed to report ready")
+
+ stdout_pipe.set_output_only()
+ return SpawnedMaxLocalStub(Int(pid))
diff --git a/tests/test_stdio_contract.mojo b/tests/test_stdio_contract.mojo
@@ -10,6 +10,10 @@ from fixture_assertions import (
load_scenario_request_json,
status_request_with_invalid_version_json,
)
+from max_local_process_helper import (
+ reserve_loopback_port,
+ spawn_max_local_stub,
+)
from stdio_process_helper import (
HYF_PATHS_PROFILE_ENV,
HYF_PATHS_REPO_LOCAL_ROOT_ENV,
@@ -39,17 +43,33 @@ def _array_contains_string(value: Value, expected: String) raises -> Bool:
return False
-def _max_local_runtime_config_toml() -> String:
+def _max_local_runtime_config_toml_with_urls(
+ base_url: String, health_url: String, request_timeout_ms: Int
+) -> String:
return (
'[service]\ntransport = "stdio"\n\n'
'[runtime]\ndefault_execution_mode = "deterministic"\nallow_assisted = true\n\n'
'[assisted]\nprovider = "max_local"\n\n'
'[assisted.max_local]\nenabled = true\n'
- 'base_url = "http://127.0.0.1:8000/v1"\n'
- 'health_url = "http://127.0.0.1:8000/health"\n'
- 'model = "max-local-query-rewrite"\n'
- 'route = "provider_runtime.query_rewrite.max_local"\n'
- 'request_timeout_ms = 15000\n'
+ 'base_url = "'
+ + base_url
+ + '"\n'
+ + 'health_url = "'
+ + health_url
+ + '"\n'
+ + 'model = "max-local-query-rewrite"\n'
+ + 'route = "provider_runtime.query_rewrite.max_local"\n'
+ + 'request_timeout_ms = '
+ + String(request_timeout_ms)
+ + "\n"
+ )
+
+
+def _max_local_runtime_config_toml() -> String:
+ return _max_local_runtime_config_toml_with_urls(
+ "http://127.0.0.1:8000/v1",
+ "http://127.0.0.1:8000/health",
+ 15000,
)
@@ -692,6 +712,87 @@ def test_query_rewrite_falls_back_deterministically_when_provider_is_unavailable
)
+def test_query_rewrite_uses_max_local_provider_when_ready() raises:  # end to end: an assisted query_rewrite answered by the local stub
+ with TemporaryDirectory() as temp_dir:
+ var provider_port = reserve_loopback_port()
+ var provider_stub = spawn_max_local_stub(
+ provider_port, "query_rewrite_ok", 2  # the stub exits after 2 requests; presumably one /health probe plus one completion POST — confirm
+ )
+ var startup_config_path = Path(temp_dir) / "explicit-hyf-config.toml"
+ startup_config_path.write_text(
+ _max_local_runtime_config_toml_with_urls(
+ "http://127.0.0.1:" + String(provider_port) + "/v1",
+ "http://127.0.0.1:" + String(provider_port) + "/health",
+ 15000,
+ )
+ )
+ with ScopedEnvVar(HYF_PATHS_PROFILE_ENV, "repo_local"):
+ with ScopedEnvVar(HYF_PATHS_REPO_LOCAL_ROOT_ENV, temp_dir):
+ var response = run_stdio_entrypoint(
+ "src/main.mojo",
+ '{"version":1,"request_id":"rewrite-assisted-max-local-1","trace_id":"rewrite-assisted-max-local-1","capability":"query_rewrite","context":{"execution_mode_preference":"assisted","return_provenance":true},"input":{"query":"local apples pickup weekend"}}',
+ "--config",
+ startup_config_path.__fspath__(),
+ )
+
+ assert_true(response["ok"].bool_value())
+ assert_equal(
+ response["meta"]["execution_mode"].string_value(),
+ "assisted",
+ )
+ assert_equal(
+ response["meta"]["backend"].string_value(),
+ "provider_runtime",
+ )
+ assert_equal(
+ response["meta"]["provider"].string_value(),
+ "max_local",
+ )
+ assert_equal(
+ response["meta"]["route"].string_value(),
+ "provider_runtime.query_rewrite.max_local",
+ )
+ assert_equal(
+ response["meta"]["model"].string_value(),
+ "max-local-query-rewrite",
+ )
+ assert_true(
+ Int(response["meta"]["latency_ms"].int_value()) >= 0  # latency is measured at run time, so only non-negativity is asserted
+ )
+ assert_equal(
+ Int(response["meta"]["schema_version"].int_value()), 1
+ )
+ assert_equal(
+ response["meta"]["prompt_version"].string_value(),
+ "max_local_query_rewrite_v1",
+ )
+ assert_equal(
+ response["meta"]["provenance"]["kind"].string_value(),
+ "assisted",
+ )
+ assert_true(
+ response["meta"]["provenance"]["fallback"].is_null()
+ )
+ assert_equal(
+ response["output"]["rewritten_text"].string_value(),
+ "apples pickup weekend",
+ )
+ assert_equal(
+ response["output"]["query_terms"][0].string_value(),
+ "apples",
+ )
+ assert_equal(
+ response["output"]["query_terms"][1].string_value(),
+ "pickup",
+ )
+ assert_equal(
+ response["output"]["query_terms"][2].string_value(),
+ "weekend",
+ )
+
+ provider_stub.wait()  # raises unless the stub process exited with code 0
+
+
+
def test_status_reports_configured_but_deferred_custody_truthfully() raises:
with TemporaryDirectory() as temp_dir:
var identity_dir = Path(temp_dir) / "secrets" / "services" / "hyf"