hyf

Context-aware query service for Radroots
git clone https://radroots.dev/git/hyf.git
Log | Files | Refs | README | LICENSE

provider_execution.mojo (9406B)


      1 from std.collections import List, Optional
      2 from std.time import perf_counter_ns
      3 
      4 from json import Value
      5 
      6 from hyf_core.backends.selector import (
      7     execute_capability as execute_backend_capability,
      8 )
      9 from hyf_core.capabilities.query_analysis import (
     10     QueryAnalysis,
     11     analyze_query_text,
     12     parse_query_rewrite_request,
     13     query_signal_tags,
     14 )
     15 from hyf_core.capabilities.query_rewrite import (
     16     build_query_rewrite_deterministic_fallback_meta,
     17     build_query_rewrite_output,
     18 )
     19 from hyf_core.errors import (
     20     CapabilityResult,
     21     failed_capability,
     22     invalid_input_error,
     23     successful_capability,
     24 )
     25 from hyf_core.provenance import (
     26     CoreResponseMeta,
     27     ExecutionProvenance,
     28     ProvenanceFallback,
     29     ProvenanceSourceRef,
     30 )
     31 from hyf_core.request_context import (
     32     RequestContext,
     33     assisted_execution_requested,
     34 )
     35 from hyf_provider.config import (
     36     MaxLocalProviderConfig,
     37     max_local_provider_config_from_runtime,
     38 )
     39 from hyf_provider.max_local import (
     40     MaxLocalQueryRewriteResult,
     41     max_local_provider_status,
     42     try_execute_query_rewrite_via_max_local_provider,
     43 )
     44 from hyf_runtime.config import (
     45     HyfLoadedRuntimeConfig,
     46     assisted_execution_enabled,
     47     assisted_runtime_configured,
     48 )
     49 from hyf_runtime.startup import RuntimeStartupContext
     50 
     51 
     52 def _source_refs(
     53     context: RequestContext, capability_name: String
     54 ) -> List[ProvenanceSourceRef]:
     55     var source_refs = List[ProvenanceSourceRef]()
     56     source_refs.append(
     57         ProvenanceSourceRef(
     58             source_kind="local_input",
     59             source_ref=capability_name + ":input",
     60         )
     61     )
     62     if context.scope:
     63         source_refs.append(
     64             ProvenanceSourceRef(
     65                 source_kind="request_scope",
     66                 source_ref="request_context.scope",
     67             )
     68         )
     69     return source_refs^
     70 
     71 
     72 def _provider_meta(
     73     context: RequestContext, result: MaxLocalQueryRewriteResult
     74 ) -> CoreResponseMeta:
     75     var provenance: Optional[ExecutionProvenance] = None
     76     if context.return_provenance:
     77         provenance = ExecutionProvenance(
     78             kind="assisted",
     79             signal_tags=query_signal_tags(result.analysis),
     80             source_refs=_source_refs(context, "query_rewrite"),
     81             fallback=None,
     82             evidence_set_id=None,
     83         )
     84 
     85     return CoreResponseMeta(
     86         execution_mode="assisted",
     87         backend="provider_runtime",
     88         provider=Optional[String](String(result.provider)),
     89         route=Optional[String](String(result.route)),
     90         model=Optional[String](String(result.model)),
     91         latency_ms=Optional[Int](result.latency_ms),
     92         schema_version=Optional[Int](result.schema_version),
     93         prompt_version=Optional[String](String(result.prompt_version)),
     94         fallback_kind=None,
     95         fallback_reason=None,
     96         provenance=provenance^,
     97     )
     98 
     99 
    100 def _provider_runtime_config_fallback_reason(
    101     config: HyfLoadedRuntimeConfig,
    102 ) -> Optional[String]:
    103     if config.load_state == "invalid":
    104         return Optional[String]("invalid_config")
    105     if not assisted_execution_enabled(config):
    106         return Optional[String]("disabled_by_runtime_config")
    107     if not assisted_runtime_configured(config):
    108         return Optional[String]("provider_unconfigured")
    109     return Optional[String](None)
    110 
    111 
    112 def _effective_provider_budget_ms(
    113     config: MaxLocalProviderConfig, context: RequestContext
    114 ) -> Int:
    115     var budget_ms = config.request_timeout_ms
    116     if context.deadline_ms > 0 and context.deadline_ms < budget_ms:
    117         budget_ms = context.deadline_ms
    118     return budget_ms
    119 
    120 
    121 def _provider_config_with_timeout(
    122     config: MaxLocalProviderConfig, timeout_ms: Int
    123 ) -> MaxLocalProviderConfig:
    124     var capped = config.copy()
    125     capped.request_timeout_ms = timeout_ms
    126     return capped^
    127 
    128 
    129 def _remaining_provider_budget_ms(start_ns: UInt, budget_ms: Int) -> Int:
    130     var elapsed_ms = Int((perf_counter_ns() - start_ns) // 1_000_000)
    131     var remaining_ms = budget_ms - elapsed_ms
    132     if remaining_ms <= 0:
    133         return 0
    134     return remaining_ms
    135 
    136 
    137 def _business_provider_status_reason(reason: String) -> String:
    138     if reason == "non_2xx":
    139         return "provider_non_2xx"
    140     return String(reason)
    141 
    142 
    143 def _query_rewrite_fallback(
    144     input: Value,
    145     context: RequestContext,
    146     fallback_kind: String,
    147     reason: String,
    148 ) raises -> CapabilityResult:
    149     try:
    150         var request = parse_query_rewrite_request(input)
    151         var analysis = analyze_query_text(request.text, context)
    152         return successful_capability(
    153             build_query_rewrite_output(analysis),
    154             meta=build_query_rewrite_deterministic_fallback_meta(
    155                 context,
    156                 analysis,
    157                 fallback_kind,
    158                 reason,
    159             ),
    160         )
    161     except e:
    162         return failed_capability(invalid_input_error(String(e)))
    163 
    164 
    165 def _with_deterministic_assisted_fallback_meta(
    166     result: CapabilityResult, fallback_kind: String, reason: String
    167 ) -> CapabilityResult:
    168     if not result.success:
    169         return result.copy()
    170 
    171     var success = result.success.value().copy()
    172     if not success.meta:
    173         return result.copy()
    174 
    175     var meta = success.meta.value().copy()
    176     meta.fallback_kind = Optional[String](String(fallback_kind))
    177     meta.fallback_reason = Optional[String](String(reason))
    178     if meta.provenance:
    179         var provenance = meta.provenance.value().copy()
    180         provenance.fallback = Optional[ProvenanceFallback](
    181             ProvenanceFallback(
    182                 fallback_kind=String(fallback_kind), reason=String(reason)
    183             )
    184         )
    185         meta.provenance = Optional[ExecutionProvenance](provenance^)
    186     return successful_capability(success.output, meta=meta^)
    187 
    188 
    189 def _execute_query_rewrite_with_provider(
    190     input: Value,
    191     context: RequestContext,
    192     runtime_context: RuntimeStartupContext,
    193 ) raises -> CapabilityResult:
    194     var config_fallback_reason = _provider_runtime_config_fallback_reason(
    195         runtime_context.config
    196     )
    197     if config_fallback_reason:
    198         return _query_rewrite_fallback(
    199             input,
    200             context,
    201             "provider_runtime",
    202             String(config_fallback_reason.value()),
    203         )
    204 
    205     try:
    206         var provider_config = max_local_provider_config_from_runtime(
    207             runtime_context.config
    208         )
    209         var budget_ms = _effective_provider_budget_ms(
    210             provider_config, context
    211         )
    212         var budget_start_ns = perf_counter_ns()
    213         var provider_status = max_local_provider_status(
    214             _provider_config_with_timeout(provider_config, budget_ms)
    215         )
    216         if provider_status.state != "ready":
    217             return _query_rewrite_fallback(
    218                 input,
    219                 context,
    220                 "provider_runtime",
    221                 _business_provider_status_reason(provider_status.reason),
    222             )
    223 
    224         var remaining_ms = _remaining_provider_budget_ms(
    225             budget_start_ns, budget_ms
    226         )
    227         if remaining_ms <= 0:
    228             return _query_rewrite_fallback(
    229                 input,
    230                 context,
    231                 "provider_runtime",
    232                 "timeout",
    233             )
    234 
    235         var request = parse_query_rewrite_request(input)
    236         remaining_ms = _remaining_provider_budget_ms(
    237             budget_start_ns, budget_ms
    238         )
    239         if remaining_ms <= 0:
    240             return _query_rewrite_fallback(
    241                 input,
    242                 context,
    243                 "provider_runtime",
    244                 "timeout",
    245             )
    246 
    247         provider_config = _provider_config_with_timeout(
    248             provider_config, remaining_ms
    249         )
    250         var outcome = try_execute_query_rewrite_via_max_local_provider(
    251             provider_config, request.text, context
    252         )
    253         if outcome.failure:
    254             return _query_rewrite_fallback(
    255                 input,
    256                 context,
    257                 "provider_runtime",
    258                 String(outcome.failure.value().reason),
    259             )
    260         if not outcome.result:
    261             return _query_rewrite_fallback(
    262                 input,
    263                 context,
    264                 "provider_runtime",
    265                 "provider_error",
    266             )
    267 
    268         var result = outcome.result.value().copy()
    269         return successful_capability(
    270             build_query_rewrite_output(result.analysis),
    271             meta=_provider_meta(context, result),
    272         )
    273     except e:
    274         return _query_rewrite_fallback(
    275             input,
    276             context,
    277             "provider_runtime",
    278             "provider_error",
    279         )
    280 
    281 
    282 def execute_runtime_aware_business_capability(
    283     capability_id: String,
    284     input: Value,
    285     context: RequestContext,
    286     runtime_context: RuntimeStartupContext,
    287 ) raises -> CapabilityResult:
    288     if not assisted_execution_requested(context):
    289         return execute_backend_capability(capability_id, input, context)
    290 
    291     if capability_id == "query_rewrite":
    292         return _execute_query_rewrite_with_provider(
    293             input, context, runtime_context
    294         )
    295 
    296     var deterministic_context = context.copy()
    297     deterministic_context.execution_mode_preference = "deterministic"
    298     return _with_deterministic_assisted_fallback_meta(
    299         execute_backend_capability(
    300             capability_id, input, deterministic_context
    301         ),
    302         "provider_runtime",
    303         "unsupported_capability",
    304     )