hyf

Context-aware query service for Radroots
git clone https://radroots.dev/git/hyf.git
Log | Files | Refs | README | LICENSE

max_local_process_helper.mojo (8310B)


      1 from std.ffi import c_int, c_size_t, c_ssize_t, external_call
      2 from std.os import Pipe, Process
      3 from std.sys._libc import close
      4 
      5 from flare.net import SocketAddr
      6 from flare.tcp import TcpListener
      7 from flare.tcp import TcpStream
      8 from flare.utils import usleep
      9 
     10 
     11 def _dup2(oldfd: c_int, newfd: c_int) -> c_int:
     12     return external_call["dup2", c_int](oldfd, newfd)
     13 
     14 
     15 @always_inline
     16 def _fork() -> c_int:
     17     return external_call["fork", c_int]()
     18 
     19 
     20 @always_inline
     21 def _exit_child(code: c_int):
     22     _ = external_call["_exit", c_int](code)
     23 
     24 
     25 def _write(fd: Int, text: String):
     26     _ = external_call["write", c_ssize_t](
     27         fd, text.as_bytes().unsafe_ptr(), c_size_t(text.byte_length())
     28     )
     29 
     30 
     31 def _read_pipe_line(mut pipe: Pipe) raises -> String:
     32     var buffer = InlineArray[Byte, 1](fill=0)
     33     var output = String("")
     34     while True:
     35         var read = pipe.read_bytes(Span(buffer))
     36         if read == 0:
     37             break
     38         var chunk = String(
     39             from_utf8=Span(ptr=buffer.unsafe_ptr(), length=Int(read))
     40         )
     41         if chunk == "\n":
     42             break
     43         output += chunk
     44     return output^
     45 
     46 
     47 def _read_request(mut stream: TcpStream) raises -> String:
     48     var buffer = List[UInt8]()
     49     buffer.resize(8192, 0)
     50     var n = stream.read(buffer.unsafe_ptr(), len(buffer))
     51     if n <= 0:
     52         return ""
     53     return String(unsafe_from_utf8=buffer[:n])
     54 
     55 
     56 def _request_path(request: String) -> String:
     57     var line_end = request.find("\r\n")
     58     if line_end < 0:
     59         return ""
     60     var first_line = String(request[byte=0:line_end])
     61     var first_space = first_line.find(" ")
     62     if first_space < 0:
     63         return ""
     64     var rest = String(first_line[byte=first_space + 1:])
     65     var second_space = rest.find(" ")
     66     if second_space < 0:
     67         return ""
     68     return String(rest[byte=0:second_space])
     69 
     70 
     71 def _json_string(value: String) -> String:
     72     return '"' + value.replace("\\", "\\\\").replace('"', '\\"') + '"'
     73 
     74 
     75 def _query_rewrite_analysis() -> String:
     76     return (
     77         '{"original_text":"local apples pickup weekend",'
     78         '"normalized_text":"local apples pickup weekend",'
     79         '"rewritten_text":"apples pickup weekend",'
     80         '"query_terms":["apples","pickup","weekend"],'
     81         '"normalization_signals":["lowercase","local_intent_detected"],'
     82         '"ranking_hints":["prefer_local_results","prefer_pickup"],'
     83         '"extracted_filters":{'
     84         '"local_intent":true,'
     85         '"fulfillment":"pickup",'
     86         '"time_window":"weekend"'
     87         "}}"
     88     )
     89 
     90 
     91 def _chat_completion(body: String) -> String:
     92     return '{"choices":[{"message":{"content":' + _json_string(body) + "}}]}"
     93 
     94 
     95 def _response(status: Int, body: String) -> String:
     96     var reason = "OK"
     97     if status == 404:
     98         reason = "Not Found"
     99     elif status == 500:
    100         reason = "Internal Server Error"
    101     elif status == 503:
    102         reason = "Service Unavailable"
    103     return (
    104         "HTTP/1.1 "
    105         + String(status)
    106         + " "
    107         + reason
    108         + "\r\ncontent-type: application/json\r\ncontent-length: "
    109         + String(body.byte_length())
    110         + "\r\nconnection: close\r\n\r\n"
    111         + body
    112     )
    113 
    114 
    115 def _send(mut stream: TcpStream, status: Int, body: String) raises:
    116     var response = _response(status, body)
    117     stream.write_all(Span[UInt8, _](response.as_bytes()))
    118 
    119 
    120 def _send_raw(mut stream: TcpStream, response: String) raises:
    121     stream.write_all(Span[UInt8, _](response.as_bytes()))
    122 
    123 
    124 def _handle_health(mut stream: TcpStream, mode: String) raises:
    125     if mode == "health_non_2xx":
    126         _send(stream, 503, '{"status":"unavailable"}')
    127     elif mode == "health_timeout":
    128         usleep(1_000_000)
    129     elif mode == "health_malformed_http":
    130         _send_raw(stream, "not an http response\r\n\r\n")
    131     elif mode == "query_rewrite_remaining_deadline_timeout":
    132         usleep(200_000)
    133         _send(stream, 200, '{"status":"ok"}')
    134     else:
    135         _send(stream, 200, '{"status":"ok"}')
    136 
    137 
    138 def _handle_chat_completions(mut stream: TcpStream, mode: String) raises:
    139     if mode == "query_rewrite_ok":
    140         _send(stream, 200, _chat_completion(_query_rewrite_analysis()))
    141     elif mode == "query_rewrite_non_2xx":
    142         _send(stream, 503, '{"error":{"message":"provider unavailable"}}')
    143     elif mode == "query_rewrite_invalid_json":
    144         _send(stream, 200, '{"choices":[{"message":{"content":"not json"}}]}')
    145     elif mode == "query_rewrite_schema_invalid":
    146         var body = (
    147             '{"original_text":"local apples pickup weekend",'
    148             '"normalized_text":"local apples pickup weekend",'
    149             '"query_terms":["apples","pickup","weekend"],'
    150             '"normalization_signals":["lowercase","local_intent_detected"],'
    151             '"ranking_hints":["prefer_local_results","prefer_pickup"],'
    152             '"extracted_filters":{'
    153             '"local_intent":true,'
    154             '"fulfillment":"pickup",'
    155             '"time_window":"weekend"'
    156             "}}"
    157         )
    158         _send(stream, 200, _chat_completion(body))
    159     elif mode == "query_rewrite_top_level_string":
    160         _send(stream, 200, '"not object"')
    161     elif mode == "query_rewrite_top_level_array":
    162         _send(stream, 200, "[]")
    163     elif mode == "query_rewrite_top_level_null":
    164         _send(stream, 200, "null")
    165     elif mode == "query_rewrite_empty_choices":
    166         _send(stream, 200, '{"choices":[]}')
    167     elif mode == "query_rewrite_missing_content":
    168         _send(stream, 200, '{"choices":[{"message":{}}]}')
    169     elif mode == "query_rewrite_error_payload":
    170         _send(stream, 200, '{"error":{"message":"provider refusal"}}')
    171     elif mode == "query_rewrite_timeout":
    172         usleep(2_000_000)
    173         _send(stream, 200, _chat_completion(_query_rewrite_analysis()))
    174     elif mode == "query_rewrite_remaining_deadline_timeout":
    175         usleep(400_000)
    176         _send(stream, 200, _chat_completion(_query_rewrite_analysis()))
    177     elif mode == "query_rewrite_malformed_http":
    178         _send_raw(stream, "not an http response\r\n\r\n")
    179     else:
    180         _send(stream, 500, '{"error":"unsupported_mode"}')
    181 
    182 
    183 def _handle_request(mut stream: TcpStream, mode: String) raises:
    184     var request = _read_request(stream)
    185     var path = _request_path(request)
    186     if path == "/health":
    187         _handle_health(stream, mode)
    188     elif path == "/v1/chat/completions":
    189         _handle_chat_completions(stream, mode)
    190     else:
    191         _send(stream, 404, '{"error":"not_found"}')
    192     stream.close()
    193 
    194 
    195 def _serve_max_local_stub(port: Int, mode: String, requests: Int) raises:
    196     var listener = TcpListener.bind(SocketAddr.localhost(UInt16(port)))
    197     _write(1, "ready\n")
    198     for _ in range(requests):
    199         var stream = listener.accept()
    200         _handle_request(stream, mode)
    201     listener.close()
    202 
    203 
    204 struct SpawnedMaxLocalStub(Movable):
    205     var pid: Int
    206 
    207     def __init__(out self, pid: Int):
    208         self.pid = pid
    209 
    210     def wait(mut self) raises:
    211         var process = Process(self.pid)
    212         var status = process.wait()
    213         if not status.exit_code or status.exit_code.value() != 0:
    214             raise Error("max_local stub exited unexpectedly")
    215 
    216 
    217 def reserve_loopback_port() raises -> Int:
    218     var listener = TcpListener.bind(SocketAddr.localhost(0))
    219     var port = Int(listener.local_addr().port)
    220     listener.close()
    221     return port
    222 
    223 
    224 def spawn_max_local_stub(
    225     port: Int, mode: String, requests: Int
    226 ) raises -> SpawnedMaxLocalStub:
    227     var stdout_pipe = Pipe()
    228     var stdout_read_fd = c_int(stdout_pipe.fd_in.value().value)
    229     var stdout_write_fd = c_int(stdout_pipe.fd_out.value().value)
    230 
    231     var pid = _fork()
    232     if pid < 0:
    233         raise Error("failed to spawn max_local stub")
    234 
    235     if pid == 0:
    236         if _dup2(stdout_write_fd, 1) < 0:
    237             _exit_child(c_int(126))
    238         _ = close(stdout_read_fd)
    239         _ = close(stdout_write_fd)
    240         try:
    241             _serve_max_local_stub(port, mode, requests)
    242             _exit_child(c_int(0))
    243         except:
    244             _exit_child(c_int(125))
    245 
    246     stdout_pipe.set_input_only()
    247     var ready_line = _read_pipe_line(stdout_pipe)
    248     if ready_line != "ready":
    249         stdout_pipe.set_output_only()
    250         var process = Process(Int(pid))
    251         _ = process.wait()
    252         raise Error("max_local stub failed to report ready")
    253 
    254     stdout_pipe.set_output_only()
    255     return SpawnedMaxLocalStub(Int(pid))