max_local_process_helper.mojo (8310B)
from std.ffi import c_int, c_size_t, c_ssize_t, external_call
from std.os import Pipe, Process
from std.sys._libc import close

from flare.net import SocketAddr
from flare.tcp import TcpListener
from flare.tcp import TcpStream
from flare.utils import usleep


def _dup2(oldfd: c_int, newfd: c_int) -> c_int:
    """Call libc `dup2(oldfd, newfd)`.

    Returns the new descriptor, or a negative value on failure (callers
    treat `< 0` as an error).
    """
    return external_call["dup2", c_int](oldfd, newfd)


@always_inline
def _fork() -> c_int:
    """Call libc `fork()`.

    Returns 0 in the child, the child's pid in the parent, and a negative
    value if the fork failed.
    """
    return external_call["fork", c_int]()


@always_inline
def _exit_child(code: c_int):
    """Terminate the current process immediately via libc `_exit(code)`.

    `_exit` (unlike `exit`) skips exit handlers and stdio flushing, which is
    presumably why the forked child uses it: state inherited from the parent
    should not be flushed or torn down a second time.
    """
    _ = external_call["_exit", c_int](code)


def _write(fd: Int, text: String):
    """Write all of `text` to the raw file descriptor `fd`, best-effort.

    Goes straight to the `write` syscall, so the bytes reach the descriptor
    without passing through any Mojo-level output buffering. Short writes are
    resumed from the first unwritten byte; an error or a zero-length write
    stops silently (the caller gets no error).
    """
    var data = text.as_bytes()
    var offset = 0
    while offset < len(data):
        var written = external_call["write", c_ssize_t](
            c_int(fd),  # write(2) takes a C int descriptor
            data.unsafe_ptr() + offset,
            c_size_t(len(data) - offset),
        )
        if written <= 0:
            return
        offset += Int(written)


def _read_pipe_line(mut pipe: Pipe) raises -> String:
    """Read one line from `pipe`, without its trailing newline.

    Stops at the first newline byte or at EOF. Reads a single byte at a time
    so nothing past the newline is consumed. The bytes are collected first
    and decoded as UTF-8 once, so a multi-byte character is never decoded
    half-read. Raises if the collected bytes are not valid UTF-8, and
    propagates any error raised by `Pipe.read_bytes`.
    """
    var buffer = InlineArray[Byte, 1](fill=0)
    var line = List[Byte]()
    var newline = Byte(ord("\n"))
    while True:
        var read = pipe.read_bytes(Span(buffer))
        if read == 0:
            break
        if buffer[0] == newline:
            break
        line.append(buffer[0])
    return String(from_utf8=Span(line))


def _expected_body_length(headers: String) -> Int:
    """Return how many body bytes to wait for after the header block.

    `headers` is the request line plus header lines, without the terminating
    blank line. Header names are matched case-insensitively. The result is
    the `Content-Length` value, or 0 when that header is missing, negative or
    unparsable. It is also 0 when the request carries an `Expect` header
    (e.g. `100-continue`): such a client holds the body back until the
    server answers, so waiting for it would only stall the exchange.
    """
    var lowered = headers.lower()
    if lowered.find("\r\nexpect:") >= 0:
        return 0
    var marker = String("\r\ncontent-length:")
    var start = lowered.find(marker)
    if start < 0:
        return 0
    start += marker.byte_length()
    var end = lowered.find("\r\n", start)
    if end < 0:
        end = lowered.byte_length()
    try:
        var length = atol(String(lowered[byte=start:end]).strip())
        if length > 0:
            return length
    except:
        # Deliberately lenient: a garbled length means "do not wait for a
        # body", not "crash the stub".
        pass
    return 0


def _read_request(
    mut stream: TcpStream, max_bytes: Int = 65536
) raises -> String:
    """Read one HTTP request (headers plus declared body) from `stream`.

    Keeps reading until one of these happens:
    - the blank line ending the headers has arrived, followed by the
      declared number of body bytes;
    - the peer stops sending (`read` returns <= 0);
    - at least `max_bytes` bytes have been collected.

    A single `read` is not enough, for two reasons. TCP may deliver the
    request line, headers and body in separate segments. Also, closing a
    socket while body bytes are still queued can make the kernel send RST
    instead of FIN, which can make the client lose the response.

    Returns "" if the peer sends nothing. Bytes are stored unchecked
    (`unsafe_from_utf8`).
    """
    var chunk = List[UInt8]()
    chunk.resize(8192, 0)
    var request = String("")
    var expected_total = -1  # header bytes + body bytes, once headers are in
    while request.byte_length() < max_bytes:
        var n = stream.read(chunk.unsafe_ptr(), len(chunk))
        if n <= 0:
            break
        # Concatenation is byte-for-byte, so a multi-byte character split
        # across two chunks is whole again once both halves are appended.
        request += String(unsafe_from_utf8=chunk[:n])
        if expected_total < 0:
            var header_end = request.find("\r\n\r\n")
            if header_end >= 0:
                expected_total = (
                    header_end
                    + 4
                    + _expected_body_length(
                        String(request[byte=0:header_end])
                    )
                )
        if expected_total >= 0 and request.byte_length() >= expected_total:
            break
    return request^


def _request_path(request: String) -> String:
    """Return the request target from the HTTP request line.

    Expects the request to start with `METHOD SP target SP version CRLF` and
    returns the text between the first two spaces of that line. Returns ""
    when there is no CRLF or the line has fewer than two spaces.
    """
    var line_end = request.find("\r\n")
    if line_end < 0:
        return ""
    var first_line = String(request[byte=0:line_end])
    var first_space = first_line.find(" ")
    if first_space < 0:
        return ""
    var rest = String(first_line[byte=first_space + 1 :])
    var second_space = rest.find(" ")
    if second_space < 0:
        return ""
    return String(rest[byte=0:second_space])


def _json_string(value: String) -> String:
    """Return `value` encoded as a quoted JSON string literal.

    Escapes the double quote, the backslash, newline, carriage return and
    tab with their short forms. Any other control byte below 0x20 becomes a
    six-character unicode escape, so the result is always valid JSON. Every
    escaped byte is ASCII, so cutting the input at those byte offsets never
    splits a multi-byte UTF-8 sequence.
    """
    var hex_digits = String("0123456789abcdef")
    var data = value.as_bytes()
    var out = String('"')
    var run_start = 0  # start of the pending run of bytes copied verbatim
    for i in range(len(data)):
        var code = Int(data[i])
        var escape = String("")
        if code == 0x22:
            escape = '\\"'
        elif code == 0x5C:
            escape = "\\\\"
        elif code == 0x0A:
            escape = "\\n"
        elif code == 0x0D:
            escape = "\\r"
        elif code == 0x09:
            escape = "\\t"
        elif code < 0x20:
            if code >= 0x10:
                escape = "\\u001"
            else:
                escape = "\\u000"
            var low = code & 0xF
            escape += String(hex_digits[byte=low : low + 1])
        else:
            continue
        out += String(value[byte=run_start:i])
        out += escape
        run_start = i + 1
    out += String(value[byte=run_start:])
    out += '"'
    return out^


def _query_rewrite_analysis() -> String:
    """Return the canned, schema-complete query-rewrite analysis JSON object.

    Used as the chat-completion content for the success (and delayed-success)
    stub modes.
    """
    return (
        '{"original_text":"local apples pickup weekend",'
        '"normalized_text":"local apples pickup weekend",'
        '"rewritten_text":"apples pickup weekend",'
        '"query_terms":["apples","pickup","weekend"],'
        '"normalization_signals":["lowercase","local_intent_detected"],'
        '"ranking_hints":["prefer_local_results","prefer_pickup"],'
        '"extracted_filters":{'
        '"local_intent":true,'
        '"fulfillment":"pickup",'
        '"time_window":"weekend"'
        "}}"
    )


def _chat_completion(body: String) -> String:
    """Wrap `body` as the message content of a one-choice chat completion.

    Produces `{"choices":[{"message":{"content":<body as JSON string>}}]}`.
    """
    return '{"choices":[{"message":{"content":' + _json_string(body) + "}}]}"


def _response(status: Int, body: String) -> String:
    """Build a complete HTTP/1.1 response carrying `body` as JSON.

    Includes `content-type: application/json`, an exact `content-length`
    (in bytes) and `connection: close`. Reason phrases exist for 404, 500
    and 503; every other status (including 200) uses "OK".
    """
    var reason = "OK"
    if status == 404:
        reason = "Not Found"
    elif status == 500:
        reason = "Internal Server Error"
    elif status == 503:
        reason = "Service Unavailable"
    return (
        "HTTP/1.1 "
        + String(status)
        + " "
        + reason
        + "\r\ncontent-type: application/json\r\ncontent-length: "
        + String(body.byte_length())
        + "\r\nconnection: close\r\n\r\n"
        + body
    )


def _send_raw(mut stream: TcpStream, response: String) raises:
    """Write `response` to `stream` verbatim (no HTTP framing is added)."""
    stream.write_all(Span[UInt8, _](response.as_bytes()))


def _send(mut stream: TcpStream, status: Int, body: String) raises:
    """Send `body` as a complete JSON HTTP response with the given status."""
    _send_raw(stream, _response(status, body))


def _handle_health(mut stream: TcpStream, mode: String) raises:
    """Answer `/health` according to the stub `mode`.

    - "health_non_2xx": 503 with `{"status":"unavailable"}`.
    - "health_timeout": sleeps via `usleep(1_000_000)` (presumably ~1 s if
      the unit is microseconds) and sends nothing.
    - "health_malformed_http": a payload that is not an HTTP response.
    - "query_rewrite_remaining_deadline_timeout": short sleep, then 200 OK,
      so the health check eats part of the caller's deadline.
    - any other mode: 200 with `{"status":"ok"}`.
    """
    if mode == "health_non_2xx":
        _send(stream, 503, '{"status":"unavailable"}')
    elif mode == "health_timeout":
        usleep(1_000_000)
    elif mode == "health_malformed_http":
        _send_raw(stream, "not an http response\r\n\r\n")
    elif mode == "query_rewrite_remaining_deadline_timeout":
        usleep(200_000)
        _send(stream, 200, '{"status":"ok"}')
    else:
        _send(stream, 200, '{"status":"ok"}')


def _handle_chat_completions(mut stream: TcpStream, mode: String) raises:
    """Answer `/v1/chat/completions` with the scenario selected by `mode`.

    Each "query_rewrite_*" mode produces one success or failure shape (non-2xx,
    non-JSON content, schema-invalid content, wrong top-level JSON types,
    empty choices, missing content, error payload, delayed reply, malformed
    HTTP). Unknown modes get a 500 `unsupported_mode` error.
    """
    if mode == "query_rewrite_ok":
        _send(stream, 200, _chat_completion(_query_rewrite_analysis()))
    elif mode == "query_rewrite_non_2xx":
        _send(stream, 503, '{"error":{"message":"provider unavailable"}}')
    elif mode == "query_rewrite_invalid_json":
        _send(stream, 200, '{"choices":[{"message":{"content":"not json"}}]}')
    elif mode == "query_rewrite_schema_invalid":
        # Same as `_query_rewrite_analysis()` but without "rewritten_text".
        var body = (
            '{"original_text":"local apples pickup weekend",'
            '"normalized_text":"local apples pickup weekend",'
            '"query_terms":["apples","pickup","weekend"],'
            '"normalization_signals":["lowercase","local_intent_detected"],'
            '"ranking_hints":["prefer_local_results","prefer_pickup"],'
            '"extracted_filters":{'
            '"local_intent":true,'
            '"fulfillment":"pickup",'
            '"time_window":"weekend"'
            "}}"
        )
        _send(stream, 200, _chat_completion(body))
    elif mode == "query_rewrite_top_level_string":
        _send(stream, 200, '"not object"')
    elif mode == "query_rewrite_top_level_array":
        _send(stream, 200, "[]")
    elif mode == "query_rewrite_top_level_null":
        _send(stream, 200, "null")
    elif mode == "query_rewrite_empty_choices":
        _send(stream, 200, '{"choices":[]}')
    elif mode == "query_rewrite_missing_content":
        _send(stream, 200, '{"choices":[{"message":{}}]}')
    elif mode == "query_rewrite_error_payload":
        _send(stream, 200, '{"error":{"message":"provider refusal"}}')
    elif mode == "query_rewrite_timeout":
        usleep(2_000_000)
        _send(stream, 200, _chat_completion(_query_rewrite_analysis()))
    elif mode == "query_rewrite_remaining_deadline_timeout":
        usleep(400_000)
        _send(stream, 200, _chat_completion(_query_rewrite_analysis()))
    elif mode == "query_rewrite_malformed_http":
        _send_raw(stream, "not an http response\r\n\r\n")
    else:
        _send(stream, 500, '{"error":"unsupported_mode"}')


def _handle_request(mut stream: TcpStream, mode: String) raises:
    """Serve exactly one request on `stream`, then close it.

    Routes `/health` and `/v1/chat/completions`; any other path (including
    an unparsable request) gets a 404.
    """
    var request = _read_request(stream)
    var path = _request_path(request)
    if path == "/health":
        _handle_health(stream, mode)
    elif path == "/v1/chat/completions":
        _handle_chat_completions(stream, mode)
    else:
        _send(stream, 404, '{"error":"not_found"}')
    stream.close()


def _serve_max_local_stub(port: Int, mode: String, requests: Int) raises:
    """Run the stub server on localhost:`port` for `requests` connections.

    Writes "ready" plus a newline to fd 1 only after the listener is bound,
    so whoever reads that line knows connections will be accepted. Handles
    connections one at a time; any error propagates to the caller.
    """
    var listener = TcpListener.bind(SocketAddr.localhost(UInt16(port)))
    _write(1, "ready\n")
    for _ in range(requests):
        var stream = listener.accept()
        _handle_request(stream, mode)
    listener.close()


struct SpawnedMaxLocalStub(Movable):
    """Handle to a forked stub process created by `spawn_max_local_stub`."""

    var pid: Int  # process id of the forked child

    def __init__(out self, pid: Int):
        self.pid = pid

    def wait(mut self) raises:
        """Block until the child exits.

        Raises if the child did not report an exit code or exited with a
        nonzero code (e.g. 125 when serving raised, 126 when `dup2` failed).
        """
        var process = Process(self.pid)
        var status = process.wait()
        if not status.exit_code or status.exit_code.value() != 0:
            raise Error("max_local stub exited unexpectedly")


def reserve_loopback_port() raises -> Int:
    """Return a localhost port number the OS reported as free.

    The probe listener is closed before returning, so another process could
    grab the port before the stub binds it; callers should tolerate a
    (rare) bind failure.
    """
    var listener = TcpListener.bind(SocketAddr.localhost(0))
    var port = Int(listener.local_addr().port)
    listener.close()
    return port


def spawn_max_local_stub(
    port: Int, mode: String, requests: Int
) raises -> SpawnedMaxLocalStub:
    """Fork a child that serves the stub on `port` for `requests` connections.

    In the child, stdout is redirected into a pipe; the parent blocks until
    the child writes its "ready" line (sent after the listener is bound).
    Exit codes of the child: 0 on success, 125 if serving raised, 126 if the
    stdout redirection failed.

    Raises if `fork` fails, or, after reaping the child, if the first line
    read is not "ready". NOTE(review): the ready-wait has no timeout, so a
    child that hangs before writing would block the caller indefinitely.
    NOTE(review): the child keeps running Mojo code after `fork`; this
    assumes the runtime has no other threads holding locks at fork time.
    """
    var stdout_pipe = Pipe()
    var stdout_read_fd = c_int(stdout_pipe.fd_in.value().value)
    var stdout_write_fd = c_int(stdout_pipe.fd_out.value().value)

    var pid = _fork()
    if pid < 0:
        raise Error("failed to spawn max_local stub")

    if pid == 0:
        # Child: route stdout into the pipe, then drop both original pipe
        # fds so only fd 1 keeps the write end open.
        if _dup2(stdout_write_fd, 1) < 0:
            _exit_child(c_int(126))
        _ = close(stdout_read_fd)
        _ = close(stdout_write_fd)
        try:
            _serve_max_local_stub(port, mode, requests)
            _exit_child(c_int(0))
        except:
            _exit_child(c_int(125))

    # Parent: close our copy of the write end so EOF is seen if the child
    # dies before reporting ready.
    stdout_pipe.set_input_only()
    var ready_line = _read_pipe_line(stdout_pipe)
    if ready_line != "ready":
        stdout_pipe.set_output_only()
        var process = Process(Int(pid))
        _ = process.wait()
        raise Error("max_local stub failed to report ready")

    stdout_pipe.set_output_only()
    return SpawnedMaxLocalStub(Int(pid))