hyf

Context-aware query service for Radroots
git clone https://radroots.dev/git/hyf.git
Log | Files | Refs | README | LICENSE

max_local.mojo (6353B)


      1 from std.collections import Optional
      2 
      3 from json import Value, loads
      4 
      5 from hyf_assist.contract import max_local_query_rewrite_route
      6 from hyf_core.capabilities.query_analysis import QueryAnalysis
      7 from hyf_core.request_context import RequestContext
      8 from hyf_provider.client import post_max_local_chat_completion
      9 from hyf_provider.config import MaxLocalProviderConfig
     10 from hyf_provider.health import resolve_max_local_provider_status
     11 from hyf_provider.result import (
     12     MaxLocalProviderStatus,
     13     parse_query_analysis_from_chat_completion,
     14 )
     15 from hyf_provider.schema import (
     16     build_query_rewrite_request_body,
     17     query_rewrite_prompt_version,
     18     query_rewrite_schema_version,
     19 )
     20 
     21 
@fieldwise_init
struct MaxLocalQueryRewriteResult(Copyable, Movable):
    """Successful query rewrite produced by the MAX local provider.

    Instances are built by `try_execute_query_rewrite_via_max_local_provider`
    once the provider response has been parsed.
    """

    # Query analysis parsed out of the chat-completion response body.
    var analysis: QueryAnalysis
    # Provider identifier; set to "max_local" by this module.
    var provider: String
    # Route name, taken from `max_local_query_rewrite_route()`.
    var route: String
    # Model name copied from the provider config (`config.model`).
    var model: String
    # Latency reported by the transport response's `latency_ms` field.
    var latency_ms: Int
    # Version of the rewrite response schema (`query_rewrite_schema_version()`).
    var schema_version: Int
    # Version of the rewrite prompt (`query_rewrite_prompt_version()`).
    var prompt_version: String
     31 
     32 
@fieldwise_init
struct MaxLocalQueryRewriteFailure(Copyable, Movable):
    """Classified failure of a MAX local query rewrite attempt."""

    # Failure category; this module produces "transport", "http_status",
    # "provider_payload" or "provider".
    var kind: String
    # Specific reason code within the category, e.g. "timeout",
    # "provider_invalid_json" or the generic "provider_error".
    var reason: String
     38 
@fieldwise_init
struct MaxLocalQueryRewriteOutcome(Copyable, Movable):
    """Result-or-failure wrapper returned by the non-raising rewrite entry point.

    The constructor helpers in this module populate exactly one of the two
    fields; the struct itself does not enforce that invariant.
    """

    # Present on success.
    var result: Optional[MaxLocalQueryRewriteResult]
    # Present on failure.
    var failure: Optional[MaxLocalQueryRewriteFailure]
     43 
     44 
     45 def _query_rewrite_success_outcome(
     46     result: MaxLocalQueryRewriteResult
     47 ) -> MaxLocalQueryRewriteOutcome:
     48     return MaxLocalQueryRewriteOutcome(
     49         result=Optional[MaxLocalQueryRewriteResult](result.copy()),
     50         failure=Optional[MaxLocalQueryRewriteFailure](None),
     51     )
     52 
     53 
     54 def _query_rewrite_failure_outcome(
     55     kind: String, reason: String
     56 ) -> MaxLocalQueryRewriteOutcome:
     57     return MaxLocalQueryRewriteOutcome(
     58         result=Optional[MaxLocalQueryRewriteResult](None),
     59         failure=Optional[MaxLocalQueryRewriteFailure](
     60             MaxLocalQueryRewriteFailure(
     61                 kind=String(kind), reason=String(reason)
     62             )
     63         ),
     64     )
     65 
     66 
     67 def max_local_query_rewrite_failure_from_reason(
     68     reason: String,
     69 ) -> MaxLocalQueryRewriteFailure:
     70     if reason == "invalid_url":
     71         return MaxLocalQueryRewriteFailure(
     72             kind="transport", reason="invalid_url"
     73         )
     74     if reason == "timeout":
     75         return MaxLocalQueryRewriteFailure(
     76             kind="transport", reason="timeout"
     77         )
     78     if reason == "connection_failed":
     79         return MaxLocalQueryRewriteFailure(
     80             kind="transport", reason="connection_failed"
     81         )
     82     if reason == "unknown_transport":
     83         return MaxLocalQueryRewriteFailure(
     84             kind="provider", reason="provider_error"
     85         )
     86     if reason == "provider_non_2xx":
     87         return MaxLocalQueryRewriteFailure(
     88             kind="http_status", reason="provider_non_2xx"
     89         )
     90     if reason == "provider_error_payload":
     91         return MaxLocalQueryRewriteFailure(
     92             kind="provider_payload", reason="provider_error_payload"
     93         )
     94     if reason == "provider_invalid_json":
     95         return MaxLocalQueryRewriteFailure(
     96             kind="provider_payload", reason="provider_invalid_json"
     97         )
     98     if reason == "provider_schema_invalid":
     99         return MaxLocalQueryRewriteFailure(
    100             kind="provider_payload", reason="provider_schema_invalid"
    101         )
    102     if reason == "provider_empty_choices":
    103         return MaxLocalQueryRewriteFailure(
    104             kind="provider_payload", reason="provider_empty_choices"
    105         )
    106     if reason == "provider_missing_content":
    107         return MaxLocalQueryRewriteFailure(
    108             kind="provider_payload", reason="provider_missing_content"
    109         )
    110     return MaxLocalQueryRewriteFailure(
    111         kind="provider", reason="provider_error"
    112     )
    113 
    114 
    115 def _load_chat_completion_response_json(text: String) raises -> Value:
    116     try:
    117         return loads(text)
    118     except:
    119         raise Error("provider_invalid_json")
    120 
    121 
def _parse_query_analysis_from_body(text: String) raises -> QueryAnalysis:
    """Parse a raw chat-completion response body into a `QueryAnalysis`.

    Args:
        text: The HTTP response body returned by the provider.

    Raises:
        Error("provider_invalid_json") if `text` is not valid JSON. Any error
        raised by `parse_query_analysis_from_chat_completion` propagates
        unchanged.
    """
    return parse_query_analysis_from_chat_completion(
        _load_chat_completion_response_json(text)
    )
    126 
    127 
    128 def execute_query_rewrite_via_max_local_provider(
    129     config: MaxLocalProviderConfig, text: String, context: RequestContext
    130 ) raises -> MaxLocalQueryRewriteResult:
    131     var outcome = try_execute_query_rewrite_via_max_local_provider(
    132         config, text, context
    133     )
    134     if outcome.result:
    135         return outcome.result.value().copy()
    136     if outcome.failure:
    137         raise Error(String(outcome.failure.value().reason))
    138     raise Error("provider_error")
    139 
    140 
    141 def try_execute_query_rewrite_via_max_local_provider(
    142     config: MaxLocalProviderConfig, text: String, context: RequestContext
    143 ) -> MaxLocalQueryRewriteOutcome:
    144     var request_body: Value
    145     try:
    146         request_body = build_query_rewrite_request_body(config, text, context)
    147     except:
    148         return _query_rewrite_failure_outcome("provider", "provider_error")
    149 
    150     var transport = post_max_local_chat_completion(
    151         config,
    152         request_body^,
    153     )
    154     if transport.failure:
    155         var failure = max_local_query_rewrite_failure_from_reason(
    156             transport.failure.value().reason
    157         )
    158         return _query_rewrite_failure_outcome(
    159             String(failure.kind), String(failure.reason)
    160         )
    161 
    162     if transport.response:
    163         try:
    164             var response = transport.response.value().copy()
    165             var analysis = _parse_query_analysis_from_body(response.body_text)
    166             return _query_rewrite_success_outcome(
    167                 MaxLocalQueryRewriteResult(
    168                     analysis=analysis^,
    169                     provider="max_local",
    170                     route=max_local_query_rewrite_route(),
    171                     model=String(config.model),
    172                     latency_ms=response.latency_ms,
    173                     schema_version=query_rewrite_schema_version(),
    174                     prompt_version=query_rewrite_prompt_version(),
    175                 )
    176             )
    177         except e:
    178             var failure = max_local_query_rewrite_failure_from_reason(
    179                 String(e)
    180             )
    181             return _query_rewrite_failure_outcome(
    182                 String(failure.kind), String(failure.reason)
    183             )
    184 
    185     return _query_rewrite_failure_outcome("provider", "provider_error")
    186 
    187 
def max_local_provider_status(
    config: MaxLocalProviderConfig,
) -> MaxLocalProviderStatus:
    """Return the MAX local provider status for `config`.

    Thin public wrapper that delegates to `resolve_max_local_provider_status`.
    What the resolver checks (e.g. whether it probes the endpoint) is not
    visible from this module. NOTE(review): confirm in hyf_provider.health.
    """
    return resolve_max_local_provider_status(config)