max_local.mojo (6353B)
from std.collections import Optional

from json import Value, loads

from hyf_assist.contract import max_local_query_rewrite_route
from hyf_core.capabilities.query_analysis import QueryAnalysis
from hyf_core.request_context import RequestContext
from hyf_provider.client import post_max_local_chat_completion
from hyf_provider.config import MaxLocalProviderConfig
from hyf_provider.health import resolve_max_local_provider_status
from hyf_provider.result import (
    MaxLocalProviderStatus,
    parse_query_analysis_from_chat_completion,
)
from hyf_provider.schema import (
    build_query_rewrite_request_body,
    query_rewrite_prompt_version,
    query_rewrite_schema_version,
)


@fieldwise_init
struct MaxLocalQueryRewriteResult(Copyable, Movable):
    """A successful query rewrite produced via the MAX local provider."""

    # Query analysis parsed from the chat-completion response body.
    var analysis: QueryAnalysis
    # Provider label; this module always sets it to "max_local".
    var provider: String
    # Route name, taken from `max_local_query_rewrite_route()`.
    var route: String
    # Model name, copied from the provider config.
    var model: String
    # Latency reported on the transport response.
    var latency_ms: Int
    # Schema version, taken from `query_rewrite_schema_version()`.
    var schema_version: Int
    # Prompt version, taken from `query_rewrite_prompt_version()`.
    var prompt_version: String


@fieldwise_init
struct MaxLocalQueryRewriteFailure(Copyable, Movable):
    """A classified query-rewrite failure.

    `kind` is a coarse category and `reason` the specific cause. The kinds
    produced by this module are "transport", "http_status",
    "provider_payload" and "provider".
    """

    var kind: String
    var reason: String


@fieldwise_init
struct MaxLocalQueryRewriteOutcome(Copyable, Movable):
    """Non-raising result of a query rewrite attempt.

    The helper constructors in this module populate exactly one of
    `result` / `failure` and leave the other empty.
    """

    var result: Optional[MaxLocalQueryRewriteResult]
    var failure: Optional[MaxLocalQueryRewriteFailure]


def _query_rewrite_success_outcome(
    result: MaxLocalQueryRewriteResult,
) -> MaxLocalQueryRewriteOutcome:
    """Wrap a copy of `result` in an outcome that carries no failure."""
    var wrapped = Optional[MaxLocalQueryRewriteResult](result.copy())
    var no_failure = Optional[MaxLocalQueryRewriteFailure](None)
    return MaxLocalQueryRewriteOutcome(result=wrapped^, failure=no_failure^)


def _query_rewrite_failure_outcome(
    kind: String, reason: String
) -> MaxLocalQueryRewriteOutcome:
    """Build an outcome that carries only a failure with `kind` / `reason`."""
    var failure = MaxLocalQueryRewriteFailure(
        kind=String(kind), reason=String(reason)
    )
    return MaxLocalQueryRewriteOutcome(
        result=Optional[MaxLocalQueryRewriteResult](None),
        failure=Optional[MaxLocalQueryRewriteFailure](failure^),
    )


def max_local_query_rewrite_failure_from_reason(
    reason: String,
) -> MaxLocalQueryRewriteFailure:
    """Classify a raw failure reason string into a kind/reason pair.

    Args:
        reason: A reason reported by the transport layer or raised while
            parsing the provider response.

    Returns:
        The classified failure. Unrecognised reasons collapse to
        kind="provider", reason="provider_error".
    """
    # Network-level problems keep their own reason under "transport".
    if (
        reason == "invalid_url"
        or reason == "timeout"
        or reason == "connection_failed"
    ):
        return MaxLocalQueryRewriteFailure(
            kind="transport", reason=String(reason)
        )

    if reason == "provider_non_2xx":
        return MaxLocalQueryRewriteFailure(
            kind="http_status", reason="provider_non_2xx"
        )

    # Problems with the content the provider sent back.
    if (
        reason == "provider_error_payload"
        or reason == "provider_invalid_json"
        or reason == "provider_schema_invalid"
        or reason == "provider_empty_choices"
        or reason == "provider_missing_content"
    ):
        return MaxLocalQueryRewriteFailure(
            kind="provider_payload", reason=String(reason)
        )

    # Everything else gets the generic provider error. This deliberately
    # includes "unknown_transport", which is NOT classified as "transport".
    return MaxLocalQueryRewriteFailure(
        kind="provider", reason="provider_error"
    )


def _query_rewrite_failure_outcome_from_reason(
    reason: String,
) -> MaxLocalQueryRewriteOutcome:
    """Classify `reason` and wrap the resulting failure in an outcome."""
    var classified = max_local_query_rewrite_failure_from_reason(reason)
    return _query_rewrite_failure_outcome(classified.kind, classified.reason)


def _load_chat_completion_response_json(text: String) raises -> Value:
    """Decode the provider response body as JSON.

    Raises:
        Error("provider_invalid_json") if `text` cannot be decoded; the
        underlying decoder error is discarded.
    """
    var decoded: Value
    try:
        decoded = loads(text)
    except:
        raise Error("provider_invalid_json")
    return decoded^


def _parse_query_analysis_from_body(text: String) raises -> QueryAnalysis:
    """Decode `text` as JSON, then extract the QueryAnalysis from it.

    Raises:
        Error("provider_invalid_json") on undecodable JSON, or whatever
        `parse_query_analysis_from_chat_completion` raises.
    """
    return parse_query_analysis_from_chat_completion(
        _load_chat_completion_response_json(text)
    )


def execute_query_rewrite_via_max_local_provider(
    config: MaxLocalProviderConfig, text: String, context: RequestContext
) raises -> MaxLocalQueryRewriteResult:
    """Raising wrapper around the non-raising rewrite entry point.

    Returns:
        A copy of the successful rewrite result.

    Raises:
        Error(reason) carrying the classified failure reason, or
        Error("provider_error") when the outcome has neither a result
        nor a failure.
    """
    var outcome = try_execute_query_rewrite_via_max_local_provider(
        config, text, context
    )
    if outcome.result:
        return outcome.result.value().copy()
    # No result: surface the recorded reason, falling back to a generic one.
    if outcome.failure:
        raise Error(String(outcome.failure.value().reason))
    raise Error("provider_error")


def try_execute_query_rewrite_via_max_local_provider(
    config: MaxLocalProviderConfig, text: String, context: RequestContext
) -> MaxLocalQueryRewriteOutcome:
    """Run a query rewrite against the MAX local provider without raising.

    Steps: build the request body, post it to the provider, then parse
    the chat-completion response into a QueryAnalysis. Every error along
    the way is reported in the returned outcome's `failure` field.
    """
    # Build the request; any error here is reported generically.
    var payload: Value
    try:
        payload = build_query_rewrite_request_body(config, text, context)
    except:
        return _query_rewrite_failure_outcome("provider", "provider_error")

    var transport = post_max_local_chat_completion(config, payload^)

    # A transport-level failure is checked before any response.
    if transport.failure:
        return _query_rewrite_failure_outcome_from_reason(
            transport.failure.value().reason
        )

    if transport.response:
        var rewrite: MaxLocalQueryRewriteResult
        try:
            var response = transport.response.value().copy()
            var parsed = _parse_query_analysis_from_body(response.body_text)
            rewrite = MaxLocalQueryRewriteResult(
                analysis=parsed^,
                provider="max_local",
                route=max_local_query_rewrite_route(),
                model=String(config.model),
                latency_ms=response.latency_ms,
                schema_version=query_rewrite_schema_version(),
                prompt_version=query_rewrite_prompt_version(),
            )
        except err:
            # The raised message is treated as a reason string and
            # classified the same way as transport reasons.
            return _query_rewrite_failure_outcome_from_reason(String(err))
        return _query_rewrite_success_outcome(rewrite)

    # Neither a failure nor a response was reported by the transport.
    return _query_rewrite_failure_outcome("provider", "provider_error")


def max_local_provider_status(
    config: MaxLocalProviderConfig,
) -> MaxLocalProviderStatus:
    """Return the provider status as resolved by the health module."""
    return resolve_max_local_provider_status(config)