hyf

Context-aware query service for Radroots
git clone https://radroots.dev/git/hyf.git
Log | Files | Refs | README | LICENSE

query_analysis.mojo (11888B)


      1 from std.collections import List, Optional
      2 
      3 from json import Value, loads
      4 
      5 from hyf_core.provenance import (
      6     CoreResponseMeta,
      7     ExecutionProvenance,
      8     ProvenanceSourceRef,
      9 )
     10 from hyf_core.request_context import RequestContext
     11 
     12 
     13 def _require_object(value: Value, context: String) raises:
     14     if not value.is_object():
     15         raise Error(context + " must be a JSON object")
     16 
     17 
     18 def _require_allowed_keys(
     19     value: Value, key_a: String, key_b: String, context: String
     20 ) raises:
     21     for key in value.object_keys():
     22         if key != key_a and key != key_b:
     23             raise Error(context + " contains unexpected field '" + key + "'")
     24 
     25 
     26 def has_key(value: Value, key: String) -> Bool:
     27     for candidate in value.object_keys():
     28         if candidate == key:
     29             return True
     30     return False
     31 
     32 
     33 def copy_string_list(items: List[String]) -> List[String]:
     34     var copied = List[String]()
     35     for item in items:
     36         copied.append(String(item))
     37     return copied^
     38 
     39 
     40 def string_array_value(items: List[String]) raises -> Value:
     41     var array = loads("[]")
     42     for item in items:
     43         array.append(Value(String(item)))
     44     return array^
     45 
     46 
     47 def collapse_whitespace(text: String) -> String:
     48     var parts = text.split()
     49     var collapsed = String()
     50     var first = True
     51     for part in parts:
     52         if not first:
     53             collapsed += " "
     54         collapsed += String(part)
     55         first = False
     56     return collapsed^
     57 
     58 
     59 def join_strings(items: List[String]) -> String:
     60     var joined = String()
     61     var first = True
     62     for item in items:
     63         if not first:
     64             joined += " "
     65         joined += String(item)
     66         first = False
     67     return joined^
     68 
     69 
     70 def normalize_free_text(text: String, mut signals: List[String]) -> String:
     71     var normalized = text.lower()
     72     if normalized != text:
     73         signals.append("lowercase")
     74 
     75     var replaced = normalized
     76     replaced = replaced.replace(",", " ")
     77     replaced = replaced.replace(".", " ")
     78     replaced = replaced.replace("!", " ")
     79     replaced = replaced.replace("?", " ")
     80     replaced = replaced.replace(":", " ")
     81     replaced = replaced.replace(";", " ")
     82     replaced = replaced.replace("/", " ")
     83     replaced = replaced.replace("\\", " ")
     84     replaced = replaced.replace("(", " ")
     85     replaced = replaced.replace(")", " ")
     86     replaced = replaced.replace("[", " ")
     87     replaced = replaced.replace("]", " ")
     88     replaced = replaced.replace("{", " ")
     89     replaced = replaced.replace("}", " ")
     90     replaced = replaced.replace("\"", " ")
     91     replaced = replaced.replace("'", " ")
     92     replaced = replaced.replace("-", " ")
     93     if replaced != normalized:
     94         signals.append("punctuation_trimmed")
     95 
     96     var collapsed = collapse_whitespace(replaced)
     97     if collapsed != replaced:
     98         signals.append("whitespace_collapsed")
     99 
    100     return collapsed^
    101 
    102 
    103 def contains_token(items: List[String], token: String) -> Bool:
    104     for item in items:
    105         if item == token:
    106             return True
    107     return False
    108 
    109 
    110 def _is_stop_word(token: String) -> Bool:
    111     return (
    112         token == "a"
    113         or token == "an"
    114         or token == "and"
    115         or token == "for"
    116         or token == "from"
    117         or token == "in"
    118         or token == "me"
    119         or token == "near"
    120         or token == "of"
    121         or token == "on"
    122         or token == "the"
    123         or token == "to"
    124         or token == "with"
    125     )
    126 
    127 
# Filter values pulled out of a query's tokens by `analyze_query_text`.
@fieldwise_init
struct ExtractedFilters(Copyable, Movable):
    # True once a "near", "me", "nearby", or "local" token has been seen.
    var local_intent: Bool
    # "pickup", "delivery", or "unspecified" when no fulfillment token matched.
    var fulfillment: String
    # "weekend", or "unspecified" when no weekend/saturday/sunday token matched.
    var time_window: String
    133 
    134 
# Result of `analyze_query_text`: the query text in its original, normalized,
# and rewritten forms, plus the tags and filters derived along the way.
@fieldwise_init
struct QueryAnalysis(Copyable, Movable):
    # Copy of the text handed to `analyze_query_text`.
    var original_text: String
    # Output of `normalize_free_text` for that text.
    var normalized_text: String
    # `query_terms` joined with single spaces.
    var rewritten_text: String
    # De-duplicated tokens left after filter words and stop words are removed;
    # holds the normalized text as its only entry when nothing else remains.
    var query_terms: List[String]
    # Tags naming the normalization and extraction steps that took effect.
    var normalization_signals: List[String]
    # Tags such as "prefer_local_results" or "respect_scope".
    var ranking_hints: List[String]
    # Local-intent, fulfillment, and time-window values found in the query.
    var extracted_filters: ExtractedFilters
    144 
    145 
# Validated input for the query_rewrite capability, as produced by
# `parse_query_rewrite_request`.
@fieldwise_init
struct QueryRewriteRequest(Copyable, Movable):
    # Whitespace-collapsed, non-empty text taken from the request's "text" or
    # "query" field.
    var text: String
    149 
    150 
    151 def extract_text_input(input: Value, capability_name: String) raises -> String:
    152     if not input.is_object():
    153         raise Error(capability_name + " input must be a JSON object")
    154 
    155     if has_key(input, "text"):
    156         var text_value = input["text"]
    157         if not text_value.is_string():
    158             raise Error(
    159                 capability_name + " input field 'text' must be a string"
    160             )
    161         var collapsed = collapse_whitespace(text_value.string_value())
    162         if collapsed == "":
    163             raise Error(capability_name + " input text must not be empty")
    164         return collapsed^
    165     elif has_key(input, "query"):
    166         var query_value = input["query"]
    167         if not query_value.is_string():
    168             raise Error(
    169                 capability_name + " input field 'query' must be a string"
    170             )
    171         var collapsed = collapse_whitespace(query_value.string_value())
    172         if collapsed == "":
    173             raise Error(capability_name + " input text must not be empty")
    174         return collapsed^
    175     else:
    176         raise Error(
    177             capability_name + " input requires 'text' or 'query'"
    178         )
    179 
    180 
    181 def parse_query_rewrite_request(input: Value) raises -> QueryRewriteRequest:
    182     _require_object(input, "query_rewrite input")
    183     _require_allowed_keys(input, "text", "query", "query_rewrite input")
    184 
    185     var has_text = has_key(input, "text")
    186     var has_query = has_key(input, "query")
    187 
    188     if has_text and has_query:
    189         raise Error(
    190             "query_rewrite input must provide exactly one of 'text' or 'query'"
    191         )
    192     if not has_text and not has_query:
    193         raise Error(
    194             "query_rewrite input requires exactly one of 'text' or 'query'"
    195         )
    196 
    197     var source_field = "text" if has_text else "query"
    198     var text_value = input[source_field]
    199     if not text_value.is_string():
    200         raise Error(
    201             "query_rewrite input field '" + source_field + "' must be a string"
    202         )
    203 
    204     var collapsed = collapse_whitespace(text_value.string_value())
    205     if collapsed == "":
    206         raise Error("query_rewrite input text must not be empty")
    207 
    208     return QueryRewriteRequest(text=collapsed)
    209 
    210 
    211 def analyze_query_text(
    212     original_text: String, context: RequestContext
    213 ) -> QueryAnalysis:
    214     var normalized_input = String(original_text)
    215 
    216     var normalization_signals = List[String]()
    217     var normalized_text = normalize_free_text(
    218         normalized_input, normalization_signals
    219     )
    220     var normalized_tokens = normalized_text.split()
    221 
    222     var query_terms = List[String]()
    223     var ranking_hints = List[String]()
    224     var local_intent = False
    225     var fulfillment = "unspecified"
    226     var time_window = "unspecified"
    227     var removed_stop_words = False
    228     var extracted_filter_tokens = False
    229 
    230     for raw_token in normalized_tokens:
    231         var token = String(raw_token)
    232         if token == "":
    233             continue
    234 
    235         if (
    236             token == "near"
    237             or token == "me"
    238             or token == "nearby"
    239             or token == "local"
    240         ):
    241             local_intent = True
    242             extracted_filter_tokens = True
    243             continue
    244 
    245         if token == "pickup" or token == "curbside":
    246             fulfillment = "pickup"
    247             extracted_filter_tokens = True
    248             continue
    249 
    250         if token == "delivery" or token == "ship" or token == "shipping":
    251             fulfillment = "delivery"
    252             extracted_filter_tokens = True
    253             continue
    254 
    255         if token == "weekend" or token == "saturday" or token == "sunday":
    256             time_window = "weekend"
    257             extracted_filter_tokens = True
    258             continue
    259 
    260         if _is_stop_word(token):
    261             removed_stop_words = True
    262             continue
    263 
    264         if not contains_token(query_terms, token):
    265             query_terms.append(token)
    266 
    267     if local_intent:
    268         normalization_signals.append("local_intent_detected")
    269         ranking_hints.append("prefer_local_results")
    270     if fulfillment == "pickup":
    271         normalization_signals.append("pickup_filter_detected")
    272         ranking_hints.append("prefer_pickup")
    273     elif fulfillment == "delivery":
    274         normalization_signals.append("delivery_filter_detected")
    275         ranking_hints.append("prefer_delivery")
    276     if time_window == "weekend":
    277         normalization_signals.append("weekend_filter_detected")
    278         ranking_hints.append("prefer_weekend_availability")
    279     if removed_stop_words:
    280         normalization_signals.append("stopwords_removed")
    281     if extracted_filter_tokens:
    282         normalization_signals.append("filter_tokens_extracted")
    283     if context.scope:
    284         ranking_hints.append("respect_scope")
    285         normalization_signals.append("scope_present")
    286 
    287     if len(query_terms) == 0:
    288         query_terms.append(String(normalized_text))
    289         normalization_signals.append("fallback_to_normalized_query")
    290 
    291     return QueryAnalysis(
    292         original_text=normalized_input,
    293         normalized_text=normalized_text,
    294         rewritten_text=join_strings(query_terms),
    295         query_terms=query_terms^,
    296         normalization_signals=normalization_signals^,
    297         ranking_hints=ranking_hints^,
    298         extracted_filters=ExtractedFilters(
    299             local_intent=local_intent,
    300             fulfillment=fulfillment,
    301             time_window=time_window,
    302         ),
    303     )
    304 
    305 
    306 def analyze_query(
    307     input: Value, context: RequestContext, capability_name: String
    308 ) raises -> QueryAnalysis:
    309     var original_text = extract_text_input(input, capability_name)
    310     return analyze_query_text(original_text, context)
    311 
    312 
    313 def serialize_extracted_filters(filters: ExtractedFilters) raises -> Value:
    314     var value = loads("{}")
    315     value.set("local_intent", Value(filters.local_intent))
    316     value.set("fulfillment", Value(String(filters.fulfillment)))
    317     value.set("time_window", Value(String(filters.time_window)))
    318     return value^
    319 
    320 
    321 def query_signal_tags(analysis: QueryAnalysis) -> List[String]:
    322     var signal_tags = copy_string_list(analysis.normalization_signals)
    323     for hint in analysis.ranking_hints:
    324         signal_tags.append(String(hint))
    325     return signal_tags^
    326 
    327 
    328 def build_deterministic_meta(
    329     context: RequestContext,
    330     capability_name: String,
    331     signal_tags: List[String],
    332     extra_source_refs: List[ProvenanceSourceRef],
    333 ) -> CoreResponseMeta:
    334     var source_refs = List[ProvenanceSourceRef]()
    335     source_refs.append(
    336         ProvenanceSourceRef(
    337             source_kind="local_input",
    338             source_ref=capability_name + ":input",
    339         )
    340     )
    341     for source_ref in extra_source_refs:
    342         source_refs.append(
    343             ProvenanceSourceRef(
    344                 source_kind=String(source_ref.source_kind),
    345                 source_ref=String(source_ref.source_ref),
    346             )
    347         )
    348     if context.scope:
    349         source_refs.append(
    350             ProvenanceSourceRef(
    351                 source_kind="request_scope",
    352                 source_ref="request_context.scope",
    353             )
    354         )
    355 
    356     if context.return_provenance:
    357         return CoreResponseMeta(
    358             execution_mode="deterministic",
    359             backend="heuristic",
    360             provider=None,
    361             route=None,
    362             model=None,
    363             latency_ms=None,
    364             schema_version=Optional[Int](1),
    365             prompt_version=None,
    366             fallback_kind=None,
    367             fallback_reason=None,
    368             provenance=ExecutionProvenance(
    369                 kind="deterministic",
    370                 signal_tags=copy_string_list(signal_tags),
    371                 source_refs=source_refs^,
    372                 fallback=None,
    373                 evidence_set_id=None,
    374             ),
    375         )
    376 
    377     return CoreResponseMeta(
    378         execution_mode="deterministic",
    379         backend="heuristic",
    380         provider=None,
    381         route=None,
    382         model=None,
    383         latency_ms=None,
    384         schema_version=Optional[Int](1),
    385         prompt_version=None,
    386         fallback_kind=None,
    387         fallback_reason=None,
    388         provenance=None,
    389     )