commit f22f74cd33d5a8735204a650603c2cc16f0056b7
parent f5475c71335c7aa8d3f69525b60ac7848c639ce7
Author: triesap <tyson@radroots.org>
Date: Mon, 15 Jun 2026 13:45:06 -0700
tests: replace max local stub with Mojo
- remove the Python MAX-local provider stub
- serve provider test responses from a forked Mojo helper
- preserve provider success and failure modes without Python helpers
Diffstat:
2 files changed, 149 insertions(+), 154 deletions(-)
diff --git a/tests/max_local_http_stub.py b/tests/max_local_http_stub.py
@@ -1,115 +0,0 @@
-import argparse
-import json
-import time
-from http.server import BaseHTTPRequestHandler, HTTPServer
-
-
-def query_rewrite_analysis():
- return {
- "original_text": "local apples pickup weekend",
- "normalized_text": "local apples pickup weekend",
- "rewritten_text": "apples pickup weekend",
- "query_terms": ["apples", "pickup", "weekend"],
- "normalization_signals": ["lowercase", "local_intent_detected"],
- "ranking_hints": ["prefer_local_results", "prefer_pickup"],
- "extracted_filters": {
- "local_intent": True,
- "fulfillment": "pickup",
- "time_window": "weekend",
- },
- }
-
-
-def chat_completion(body):
- return {
- "choices": [
- {
- "message": {
- "content": json.dumps(body, separators=(",", ":"))
- }
- }
- ]
- }
-
-
-class StubServer(HTTPServer):
- def __init__(self, server_address, handler_class, mode, requests):
- super().__init__(server_address, handler_class)
- self.mode = mode
- self.requests_remaining = requests
-
-
-class Handler(BaseHTTPRequestHandler):
- def log_message(self, format, *args):
- return
-
- def _send(self, status, payload, content_type="application/json"):
- if isinstance(payload, str):
- body = payload.encode("utf-8")
- else:
- body = json.dumps(payload, separators=(",", ":")).encode("utf-8")
- self.send_response(status)
- self.send_header("content-type", content_type)
- self.send_header("content-length", str(len(body)))
- self.end_headers()
- try:
- self.wfile.write(body)
- except (BrokenPipeError, ConnectionResetError):
- return
-
- def do_GET(self):
- if self.path == "/health":
- if self.server.mode == "health_non_2xx":
- self._send(503, {"status": "unavailable"})
- else:
- self._send(200, {"status": "ok"})
- return
- self._send(404, {"error": "not_found"})
-
- def do_POST(self):
- if self.path != "/v1/chat/completions":
- self._send(404, {"error": "not_found"})
- return
- mode = self.server.mode
- if mode == "query_rewrite_ok":
- self._send(200, chat_completion(query_rewrite_analysis()))
- elif mode == "query_rewrite_non_2xx":
- self._send(503, {"error": {"message": "provider unavailable"}})
- elif mode == "query_rewrite_invalid_json":
- self._send(200, '{"choices":[{"message":{"content":"not json"}}]}')
- elif mode == "query_rewrite_schema_invalid":
- body = query_rewrite_analysis()
- del body["rewritten_text"]
- self._send(200, chat_completion(body))
- elif mode == "query_rewrite_empty_choices":
- self._send(200, {"choices": []})
- elif mode == "query_rewrite_missing_content":
- self._send(200, {"choices": [{"message": {}}]})
- elif mode == "query_rewrite_error_payload":
- self._send(200, {"error": {"message": "provider refusal"}})
- elif mode == "query_rewrite_timeout":
- time.sleep(2)
- self._send(200, chat_completion(query_rewrite_analysis()))
- else:
- self._send(500, {"error": "unsupported_mode"})
-
-
-def main():
- parser = argparse.ArgumentParser()
- parser.add_argument("--port", type=int, required=True)
- parser.add_argument("--mode", required=True)
- parser.add_argument("--requests", type=int, default=2)
- args = parser.parse_args()
-
- server = StubServer(
- ("127.0.0.1", args.port), Handler, args.mode, args.requests
- )
- print("ready", flush=True)
- while server.requests_remaining > 0:
- server.handle_request()
- server.requests_remaining -= 1
- server.server_close()
-
-
-if __name__ == "__main__":
- main()
diff --git a/tests/max_local_process_helper.mojo b/tests/max_local_process_helper.mojo
@@ -1,10 +1,11 @@
-from std.collections import List, Optional
-from std.ffi import CStringSlice, c_int, external_call
+from std.ffi import c_int, c_size_t, c_ssize_t, external_call
from std.os import Pipe, Process
from std.sys._libc import close
from flare.net import SocketAddr
from flare.tcp import TcpListener
+from flare.tcp import TcpStream
+from flare.utils import usleep
def _dup2(oldfd: c_int, newfd: c_int) -> c_int:
@@ -21,6 +22,12 @@ def _exit_child(code: c_int):
_ = external_call["_exit", c_int](code)
+def _write(fd: Int, text: String):
+ _ = external_call["write", c_ssize_t](
+ fd, text.as_bytes().unsafe_ptr(), c_size_t(text.byte_length())
+ )
+
+
def _read_pipe_line(mut pipe: Pipe) raises -> String:
var buffer = InlineArray[Byte, 1](fill=0)
var output = String("")
@@ -37,6 +44,141 @@ def _read_pipe_line(mut pipe: Pipe) raises -> String:
return output^
+def _read_request(mut stream: TcpStream) raises -> String:
+ var buffer = List[UInt8]()
+ buffer.resize(8192, 0)
+ var n = stream.read(buffer.unsafe_ptr(), len(buffer))
+ if n <= 0:
+ return ""
+ return String(unsafe_from_utf8=buffer[:n])
+
+
+def _request_path(request: String) -> String:
+ var line_end = request.find("\r\n")
+ if line_end < 0:
+ return ""
+ var first_line = String(request[byte=0:line_end])
+ var first_space = first_line.find(" ")
+ if first_space < 0:
+ return ""
+ var rest = String(first_line[byte=first_space + 1:])
+ var second_space = rest.find(" ")
+ if second_space < 0:
+ return ""
+ return String(rest[byte=0:second_space])
+
+
+def _json_string(value: String) -> String:
+ return '"' + value.replace("\\", "\\\\").replace('"', '\\"') + '"'
+
+
+def _query_rewrite_analysis() -> String:
+ return (
+ '{"original_text":"local apples pickup weekend",'
+ '"normalized_text":"local apples pickup weekend",'
+ '"rewritten_text":"apples pickup weekend",'
+ '"query_terms":["apples","pickup","weekend"],'
+ '"normalization_signals":["lowercase","local_intent_detected"],'
+ '"ranking_hints":["prefer_local_results","prefer_pickup"],'
+ '"extracted_filters":{'
+ '"local_intent":true,'
+ '"fulfillment":"pickup",'
+ '"time_window":"weekend"'
+ "}}"
+ )
+
+
+def _chat_completion(body: String) -> String:
+ return '{"choices":[{"message":{"content":' + _json_string(body) + "}}]}"
+
+
+def _response(status: Int, body: String) -> String:
+ var reason = "OK"
+ if status == 404:
+ reason = "Not Found"
+ elif status == 500:
+ reason = "Internal Server Error"
+ elif status == 503:
+ reason = "Service Unavailable"
+ return (
+ "HTTP/1.1 "
+ + String(status)
+ + " "
+ + reason
+ + "\r\ncontent-type: application/json\r\ncontent-length: "
+ + String(body.byte_length())
+ + "\r\nconnection: close\r\n\r\n"
+ + body
+ )
+
+
+def _send(mut stream: TcpStream, status: Int, body: String) raises:
+ var response = _response(status, body)
+ stream.write_all(Span[UInt8, _](response.as_bytes()))
+
+
+def _handle_health(mut stream: TcpStream, mode: String) raises:
+ if mode == "health_non_2xx":
+ _send(stream, 503, '{"status":"unavailable"}')
+ else:
+ _send(stream, 200, '{"status":"ok"}')
+
+
+def _handle_chat_completions(mut stream: TcpStream, mode: String) raises:
+ if mode == "query_rewrite_ok":
+ _send(stream, 200, _chat_completion(_query_rewrite_analysis()))
+ elif mode == "query_rewrite_non_2xx":
+ _send(stream, 503, '{"error":{"message":"provider unavailable"}}')
+ elif mode == "query_rewrite_invalid_json":
+ _send(stream, 200, '{"choices":[{"message":{"content":"not json"}}]}')
+ elif mode == "query_rewrite_schema_invalid":
+ var body = (
+ '{"original_text":"local apples pickup weekend",'
+ '"normalized_text":"local apples pickup weekend",'
+ '"query_terms":["apples","pickup","weekend"],'
+ '"normalization_signals":["lowercase","local_intent_detected"],'
+ '"ranking_hints":["prefer_local_results","prefer_pickup"],'
+ '"extracted_filters":{'
+ '"local_intent":true,'
+ '"fulfillment":"pickup",'
+ '"time_window":"weekend"'
+ "}}"
+ )
+ _send(stream, 200, _chat_completion(body))
+ elif mode == "query_rewrite_empty_choices":
+ _send(stream, 200, '{"choices":[]}')
+ elif mode == "query_rewrite_missing_content":
+ _send(stream, 200, '{"choices":[{"message":{}}]}')
+ elif mode == "query_rewrite_error_payload":
+ _send(stream, 200, '{"error":{"message":"provider refusal"}}')
+ elif mode == "query_rewrite_timeout":
+ usleep(2_000_000)
+ _send(stream, 200, _chat_completion(_query_rewrite_analysis()))
+ else:
+ _send(stream, 500, '{"error":"unsupported_mode"}')
+
+
+def _handle_request(mut stream: TcpStream, mode: String) raises:
+ var request = _read_request(stream)
+ var path = _request_path(request)
+ if path == "/health":
+ _handle_health(stream, mode)
+ elif path == "/v1/chat/completions":
+ _handle_chat_completions(stream, mode)
+ else:
+ _send(stream, 404, '{"error":"not_found"}')
+ stream.close()
+
+
+def _serve_max_local_stub(port: Int, mode: String, requests: Int) raises:
+ var listener = TcpListener.bind(SocketAddr.localhost(UInt16(port)))
+ _write(1, "ready\n")
+ for _ in range(requests):
+ var stream = listener.accept()
+ _handle_request(stream, mode)
+ listener.close()
+
+
struct SpawnedMaxLocalStub(Movable):
var pid: Int
@@ -61,43 +203,8 @@ def spawn_max_local_stub(
port: Int, mode: String, requests: Int
) raises -> SpawnedMaxLocalStub:
var stdout_pipe = Pipe()
- var command = String("python")
- var entrypoint = String("tests/max_local_http_stub.py")
- var arg_port_flag = String("--port")
- var arg_port = String(port)
- var arg_mode_flag = String("--mode")
- var arg_mode = String(mode)
- var arg_requests_flag = String("--requests")
- var arg_requests = String(requests)
- var argv = List[Optional[CStringSlice[ImmutAnyOrigin]]](length=9, fill={})
- argv[0] = rebind[CStringSlice[ImmutAnyOrigin]](
- command.as_c_string_slice()
- )
- argv[1] = rebind[CStringSlice[ImmutAnyOrigin]](
- entrypoint.as_c_string_slice()
- )
- argv[2] = rebind[CStringSlice[ImmutAnyOrigin]](
- arg_port_flag.as_c_string_slice()
- )
- argv[3] = rebind[CStringSlice[ImmutAnyOrigin]](
- arg_port.as_c_string_slice()
- )
- argv[4] = rebind[CStringSlice[ImmutAnyOrigin]](
- arg_mode_flag.as_c_string_slice()
- )
- argv[5] = rebind[CStringSlice[ImmutAnyOrigin]](
- arg_mode.as_c_string_slice()
- )
- argv[6] = rebind[CStringSlice[ImmutAnyOrigin]](
- arg_requests_flag.as_c_string_slice()
- )
- argv[7] = rebind[CStringSlice[ImmutAnyOrigin]](
- arg_requests.as_c_string_slice()
- )
var stdout_read_fd = c_int(stdout_pipe.fd_in.value().value)
var stdout_write_fd = c_int(stdout_pipe.fd_out.value().value)
- var command_ptr = command.as_c_string_slice().unsafe_ptr()
- var argv_ptr = argv.unsafe_ptr()
var pid = _fork()
if pid < 0:
@@ -108,8 +215,11 @@ def spawn_max_local_stub(
_exit_child(c_int(126))
_ = close(stdout_read_fd)
_ = close(stdout_write_fd)
- _ = external_call["execvp", c_int](command_ptr, argv_ptr)
- _exit_child(c_int(127))
+ try:
+ _serve_max_local_stub(port, mode, requests)
+ _exit_child(c_int(0))
+ except:
+ _exit_child(c_int(125))
stdout_pipe.set_input_only()
var ready_line = _read_pipe_line(stdout_pipe)