provider_execution.mojo (9406B)
"""Runtime-aware dispatch for business capabilities.

Routes a capability request either straight to the backend selector or, when
assisted execution is requested, through the local MAX provider (currently
only for `query_rewrite`). Assisted requests the provider cannot serve are
answered by the deterministic implementation, tagged with a fallback kind and
reason where response metadata is available.
"""
from std.collections import List, Optional
from std.time import perf_counter_ns

from json import Value

from hyf_core.backends.selector import (
    execute_capability as execute_backend_capability,
)
from hyf_core.capabilities.query_analysis import (
    QueryAnalysis,
    analyze_query_text,
    parse_query_rewrite_request,
    query_signal_tags,
)
from hyf_core.capabilities.query_rewrite import (
    build_query_rewrite_deterministic_fallback_meta,
    build_query_rewrite_output,
)
from hyf_core.errors import (
    CapabilityResult,
    failed_capability,
    invalid_input_error,
    successful_capability,
)
from hyf_core.provenance import (
    CoreResponseMeta,
    ExecutionProvenance,
    ProvenanceFallback,
    ProvenanceSourceRef,
)
from hyf_core.request_context import (
    RequestContext,
    assisted_execution_requested,
)
from hyf_provider.config import (
    MaxLocalProviderConfig,
    max_local_provider_config_from_runtime,
)
from hyf_provider.max_local import (
    MaxLocalQueryRewriteResult,
    max_local_provider_status,
    try_execute_query_rewrite_via_max_local_provider,
)
from hyf_runtime.config import (
    HyfLoadedRuntimeConfig,
    assisted_execution_enabled,
    assisted_runtime_configured,
)
from hyf_runtime.startup import RuntimeStartupContext


def _source_refs(
    context: RequestContext, capability_name: String
) -> List[ProvenanceSourceRef]:
    """Build the provenance source references for one capability invocation.

    Args:
        context: The incoming request context; only `context.scope` is read.
        capability_name: Capability name used to label the input reference.

    Returns:
        A list that always starts with a `local_input` reference named
        `<capability_name>:input`, followed by a `request_scope` reference
        when `context.scope` is truthy.
    """
    var source_refs = List[ProvenanceSourceRef]()
    source_refs.append(
        ProvenanceSourceRef(
            source_kind="local_input",
            source_ref=capability_name + ":input",
        )
    )
    if context.scope:
        source_refs.append(
            ProvenanceSourceRef(
                source_kind="request_scope",
                source_ref="request_context.scope",
            )
        )
    return source_refs^


def _provider_meta(
    context: RequestContext, result: MaxLocalQueryRewriteResult
) -> CoreResponseMeta:
    """Build response metadata for a query rewrite served by the provider.

    Args:
        context: Request context; `return_provenance` decides whether a
            provenance record is attached.
        result: The provider's query-rewrite result. Its provider, route,
            model, latency, and schema/prompt versions are copied into the
            metadata, and its analysis supplies the provenance signal tags.

    Returns:
        `CoreResponseMeta` with `execution_mode="assisted"`,
        `backend="provider_runtime"`, and no fallback fields set.
    """
    # Provenance is opt-in per request; it stays None unless requested.
    var provenance: Optional[ExecutionProvenance] = None
    if context.return_provenance:
        provenance = ExecutionProvenance(
            kind="assisted",
            signal_tags=query_signal_tags(result.analysis),
            source_refs=_source_refs(context, "query_rewrite"),
            fallback=None,
            evidence_set_id=None,
        )

    return CoreResponseMeta(
        execution_mode="assisted",
        backend="provider_runtime",
        provider=Optional[String](String(result.provider)),
        route=Optional[String](String(result.route)),
        model=Optional[String](String(result.model)),
        latency_ms=Optional[Int](result.latency_ms),
        schema_version=Optional[Int](result.schema_version),
        prompt_version=Optional[String](String(result.prompt_version)),
        fallback_kind=None,
        fallback_reason=None,
        provenance=provenance^,
    )


def _provider_runtime_config_fallback_reason(
    config: HyfLoadedRuntimeConfig,
) -> Optional[String]:
    """Return why the loaded runtime config rules out assisted execution.

    Checks run in order and the first match wins:
    1. an invalid config load;
    2. assisted execution disabled;
    3. the assisted runtime not configured.

    Args:
        config: The loaded runtime configuration.

    Returns:
        `"invalid_config"`, `"disabled_by_runtime_config"` or
        `"provider_unconfigured"` when assisted execution must be skipped;
        an empty Optional when the config permits it.
    """
    if config.load_state == "invalid":
        return Optional[String]("invalid_config")
    if not assisted_execution_enabled(config):
        return Optional[String]("disabled_by_runtime_config")
    if not assisted_runtime_configured(config):
        return Optional[String]("provider_unconfigured")
    return Optional[String](None)


def _effective_provider_budget_ms(
    config: MaxLocalProviderConfig, context: RequestContext
) -> Int:
    """Return the total time budget, in milliseconds, for the provider path.

    Starts from the provider's `request_timeout_ms` and lowers it to the
    request's `deadline_ms` when that is positive and smaller. A
    non-positive `deadline_ms` is ignored.

    Args:
        config: Provider config supplying the default request timeout.
        context: Request context supplying the optional deadline.

    Returns:
        The smaller of the two limits, as described above.
    """
    # NOTE(review): treats `deadline_ms` as a relative duration comparable
    # to a timeout, not an absolute timestamp — confirm against
    # RequestContext.
    var budget_ms = config.request_timeout_ms
    if context.deadline_ms > 0 and context.deadline_ms < budget_ms:
        budget_ms = context.deadline_ms
    return budget_ms


def _provider_config_with_timeout(
    config: MaxLocalProviderConfig, timeout_ms: Int
) -> MaxLocalProviderConfig:
    """Return a copy of `config` with `request_timeout_ms` set to `timeout_ms`.

    The input config is not modified.

    Args:
        config: The provider config to copy.
        timeout_ms: The timeout to put on the copy.

    Returns:
        The modified copy.
    """
    var capped = config.copy()
    capped.request_timeout_ms = timeout_ms
    return capped^


def _remaining_provider_budget_ms(start_ns: UInt, budget_ms: Int) -> Int:
    """Return how many milliseconds of `budget_ms` remain since `start_ns`.

    Args:
        start_ns: A `perf_counter_ns()` reading taken when the budget started.
        budget_ms: The total budget in milliseconds.

    Returns:
        `budget_ms` minus the elapsed time, never less than 0.
    """
    # Integer division truncates, so elapsed time is rounded down to whole ms.
    var elapsed_ms = Int((perf_counter_ns() - start_ns) // 1_000_000)
    var remaining_ms = budget_ms - elapsed_ms
    if remaining_ms <= 0:
        return 0
    return remaining_ms


def _business_provider_status_reason(reason: String) -> String:
    """Map a provider status reason to the reason reported to callers.

    Only `"non_2xx"` is renamed (to `"provider_non_2xx"`); any other reason
    is returned unchanged.

    Args:
        reason: The reason reported by the provider status probe.

    Returns:
        The caller-facing reason string.
    """
    if reason == "non_2xx":
        return "provider_non_2xx"
    return String(reason)


def _query_rewrite_fallback(
    input: Value,
    context: RequestContext,
    fallback_kind: String,
    reason: String,
) raises -> CapabilityResult:
    """Answer a query rewrite deterministically and tag it as a fallback.

    Parses the request and runs local query analysis. It then returns the
    rewrite output with metadata from
    `build_query_rewrite_deterministic_fallback_meta`, which receives
    `fallback_kind` and `reason`.

    Args:
        input: The raw capability input.
        context: The request context, passed to analysis and metadata.
        fallback_kind: Fallback category recorded in the metadata.
        reason: Why the fallback happened, recorded in the metadata.

    Returns:
        A successful capability result. If any step in the `try` block
        raised, a failed result carrying `invalid_input_error` with the
        error's message.

    Raises:
        Declared `raises`. Errors from the `try` body are converted to a
        failed result, so only errors from `failed_capability` /
        `invalid_input_error` themselves could propagate. Whether those can
        raise at all is not visible here.
    """
    try:
        var request = parse_query_rewrite_request(input)
        var analysis = analyze_query_text(request.text, context)
        return successful_capability(
            build_query_rewrite_output(analysis),
            meta=build_query_rewrite_deterministic_fallback_meta(
                context,
                analysis,
                fallback_kind,
                reason,
            ),
        )
    except e:
        # NOTE(review): every error here, not only parse errors, is reported
        # as invalid input. Confirm that analysis and meta-building cannot
        # fail for other reasons.
        return failed_capability(invalid_input_error(String(e)))


def _with_deterministic_assisted_fallback_meta(
    result: CapabilityResult, fallback_kind: String, reason: String
) -> CapabilityResult:
    """Mark a deterministic result as a fallback from assisted execution.

    On a successful result that has metadata, sets `fallback_kind` and
    `fallback_reason`. When provenance is present, it also records the same
    pair in `provenance.fallback`. The output itself is passed through.

    Args:
        result: The capability result to annotate.
        fallback_kind: Fallback category to record.
        reason: Fallback reason to record.

    Returns:
        A new result with the annotated metadata. Results without a success
        payload, and successes without metadata, are returned as unmodified
        copies.
    """
    if not result.success:
        return result.copy()

    var success = result.success.value().copy()
    # NOTE(review): a success without metadata is returned without any
    # fallback marker, so callers cannot tell it was a fallback. Confirm
    # this is acceptable.
    if not success.meta:
        return result.copy()

    var meta = success.meta.value().copy()
    meta.fallback_kind = Optional[String](String(fallback_kind))
    meta.fallback_reason = Optional[String](String(reason))
    if meta.provenance:
        # Mirror the fallback into provenance so both views agree.
        var provenance = meta.provenance.value().copy()
        provenance.fallback = Optional[ProvenanceFallback](
            ProvenanceFallback(
                fallback_kind=String(fallback_kind), reason=String(reason)
            )
        )
        meta.provenance = Optional[ExecutionProvenance](provenance^)
    return successful_capability(success.output, meta=meta^)


def _execute_query_rewrite_with_provider(
    input: Value,
    context: RequestContext,
    runtime_context: RuntimeStartupContext,
) raises -> CapabilityResult:
    """Run `query_rewrite` through the local MAX provider, with fallbacks.

    Every path that does not produce a provider result falls back to
    `_query_rewrite_fallback` with fallback kind `"provider_runtime"`:

    - runtime config rules out assisted execution -> the config reason;
    - provider status is not `"ready"` -> the mapped status reason;
    - the time budget is used up before the call -> `"timeout"`;
    - the provider reports a failure -> the failure's reason;
    - the provider returns neither failure nor result, or anything in the
      provider block raises -> `"provider_error"`.

    Args:
        input: The raw capability input.
        context: Request context; its deadline can shorten the budget.
        runtime_context: Startup context holding the loaded runtime config.

    Returns:
        A successful result built from the provider's analysis with
        assisted metadata, or a fallback result as described above.

    Raises:
        Errors from the fallback calls made outside the `try` block are not
        caught here: the config-gating fallback and the one in the `except`
        branch.
    """
    # Gate on runtime config before touching the provider at all.
    var config_fallback_reason = _provider_runtime_config_fallback_reason(
        runtime_context.config
    )
    if config_fallback_reason:
        return _query_rewrite_fallback(
            input,
            context,
            "provider_runtime",
            String(config_fallback_reason.value()),
        )

    try:
        var provider_config = max_local_provider_config_from_runtime(
            runtime_context.config
        )
        var budget_ms = _effective_provider_budget_ms(
            provider_config, context
        )
        # The budget clock starts here, after config resolution, so that
        # work is not charged against the budget.
        var budget_start_ns = perf_counter_ns()
        # The status probe may use the whole budget. The actual call later
        # gets only what is left of it.
        var provider_status = max_local_provider_status(
            _provider_config_with_timeout(provider_config, budget_ms)
        )
        if provider_status.state != "ready":
            return _query_rewrite_fallback(
                input,
                context,
                "provider_runtime",
                _business_provider_status_reason(provider_status.reason),
            )

        var remaining_ms = _remaining_provider_budget_ms(
            budget_start_ns, budget_ms
        )
        if remaining_ms <= 0:
            return _query_rewrite_fallback(
                input,
                context,
                "provider_runtime",
                "timeout",
            )

        var request = parse_query_rewrite_request(input)
        # Re-check after parsing so the provider call is only attempted
        # while some budget remains.
        remaining_ms = _remaining_provider_budget_ms(
            budget_start_ns, budget_ms
        )
        if remaining_ms <= 0:
            return _query_rewrite_fallback(
                input,
                context,
                "provider_runtime",
                "timeout",
            )

        # Cap the provider call's timeout at the remaining budget.
        provider_config = _provider_config_with_timeout(
            provider_config, remaining_ms
        )
        var outcome = try_execute_query_rewrite_via_max_local_provider(
            provider_config, request.text, context
        )
        # NOTE(review): unlike the status path, failure reasons are not passed
        # through `_business_provider_status_reason`, so a raw "non_2xx" would
        # surface as-is. Confirm the provider already returns caller-facing
        # reasons.
        if outcome.failure:
            return _query_rewrite_fallback(
                input,
                context,
                "provider_runtime",
                String(outcome.failure.value().reason),
            )
        if not outcome.result:
            return _query_rewrite_fallback(
                input,
                context,
                "provider_runtime",
                "provider_error",
            )

        var result = outcome.result.value().copy()
        return successful_capability(
            build_query_rewrite_output(result.analysis),
            meta=_provider_meta(context, result),
        )
    except e:
        # Any error in the provider block becomes "provider_error". If the
        # input itself is malformed, the fallback re-parses it and presumably
        # returns its invalid-input failure instead.
        return _query_rewrite_fallback(
            input,
            context,
            "provider_runtime",
            "provider_error",
        )


def execute_runtime_aware_business_capability(
    capability_id: String,
    input: Value,
    context: RequestContext,
    runtime_context: RuntimeStartupContext,
) raises -> CapabilityResult:
    """Execute a business capability, honoring assisted-execution requests.

    - Assisted execution not requested: delegate to the backend selector
      with the original context.
    - `query_rewrite`: try the local provider, with deterministic fallback.
    - Any other capability: run it on the backend with
      `execution_mode_preference` forced to `"deterministic"`. The result is
      tagged as a `"provider_runtime"` / `"unsupported_capability"`
      fallback.

    Args:
        capability_id: Identifier of the capability to run.
        input: The raw capability input.
        context: The request context.
        runtime_context: Startup context holding the loaded runtime config.

    Returns:
        The capability result.

    Raises:
        Propagates errors raised by the backend selector or the provider path.
    """
    if not assisted_execution_requested(context):
        return execute_backend_capability(capability_id, input, context)

    if capability_id == "query_rewrite":
        return _execute_query_rewrite_with_provider(
            input, context, runtime_context
        )

    # Only query_rewrite has an assisted implementation. Run the rest
    # deterministically on a copy, so the caller's context is not modified.
    var deterministic_context = context.copy()
    deterministic_context.execution_mode_preference = "deterministic"
    return _with_deterministic_assisted_fallback_meta(
        execute_backend_capability(
            capability_id, input, deterministic_context
        ),
        "provider_runtime",
        "unsupported_capability",
    )