query_analysis.mojo (11888B)
from std.collections import List, Optional

from json import Value, loads

from hyf_core.provenance import (
    CoreResponseMeta,
    ExecutionProvenance,
    ProvenanceSourceRef,
)
from hyf_core.request_context import RequestContext


def _require_object(value: Value, context: String) raises:
    """Raise unless `value` is a JSON object.

    Args:
        value: The JSON value to check.
        context: Prefix for the error message (e.g. "query_rewrite input").

    Raises:
        Error: "<context> must be a JSON object" when `value` is not an object.
    """
    if not value.is_object():
        raise Error(context + " must be a JSON object")


def _require_allowed_keys(
    value: Value, key_a: String, key_b: String, context: String
) raises:
    """Raise if `value` has any key other than `key_a` or `key_b`.

    Args:
        value: JSON object to inspect (the caller in this file checks it with
            `_require_object` first).
        key_a: First permitted key.
        key_b: Second permitted key.
        context: Prefix for the error message.

    Raises:
        Error: "<context> contains unexpected field '<key>'" for the first
            key that is not permitted.
    """
    for key in value.object_keys():
        if key != key_a and key != key_b:
            raise Error(context + " contains unexpected field '" + key + "'")


def has_key(value: Value, key: String) -> Bool:
    """Return True if the JSON object `value` contains `key`.

    This is a linear scan over `value.object_keys()`.
    """
    for candidate in value.object_keys():
        if candidate == key:
            return True
    return False


def copy_string_list(items: List[String]) -> List[String]:
    """Return a new list holding copies of every string in `items`."""
    var copied = List[String]()
    for item in items:
        copied.append(String(item))
    return copied^


def string_array_value(items: List[String]) raises -> Value:
    """Build a JSON array `Value` containing each string in `items`, in order.
    """
    var array = loads("[]")
    for item in items:
        array.append(Value(String(item)))
    return array^


def collapse_whitespace(text: String) -> String:
    """Split `text` on whitespace and rejoin the pieces with single spaces.

    Relies on `String.split()` with no separator discarding empty pieces, so
    an all-whitespace input yields "" (callers below test for that).
    """
    var parts = text.split()
    var collapsed = String()
    var first = True
    for part in parts:
        if not first:
            collapsed += " "
        collapsed += String(part)
        first = False
    return collapsed^


def join_strings(items: List[String]) -> String:
    """Join `items` with single spaces (no leading or trailing separator)."""
    var joined = String()
    var first = True
    for item in items:
        if not first:
            joined += " "
        joined += String(item)
        first = False
    return joined^


def normalize_free_text(text: String, mut signals: List[String]) -> String:
    """Lowercase `text`, turn punctuation into spaces, and collapse whitespace.

    Each step that actually changes the text appends a tag to `signals`:
    "lowercase", "punctuation_trimmed", "whitespace_collapsed".

    Args:
        text: Raw query text.
        signals: Output list that receives the normalization tags.

    Returns:
        The normalized text. It may be empty when `text` consisted only of
        the punctuation below and whitespace.
    """
    var normalized = text.lower()
    if normalized != text:
        signals.append("lowercase")

    # Punctuation becomes a space (not "") so that words joined by it, e.g.
    # "red,blue", split into separate tokens; extra spaces collapse below.
    var replaced = normalized
    replaced = replaced.replace(",", " ")
    replaced = replaced.replace(".", " ")
    replaced = replaced.replace("!", " ")
    replaced = replaced.replace("?", " ")
    replaced = replaced.replace(":", " ")
    replaced = replaced.replace(";", " ")
    replaced = replaced.replace("/", " ")
    replaced = replaced.replace("\\", " ")
    replaced = replaced.replace("(", " ")
    replaced = replaced.replace(")", " ")
    replaced = replaced.replace("[", " ")
    replaced = replaced.replace("]", " ")
    replaced = replaced.replace("{", " ")
    replaced = replaced.replace("}", " ")
    replaced = replaced.replace("\"", " ")
    replaced = replaced.replace("'", " ")
    replaced = replaced.replace("-", " ")
    if replaced != normalized:
        signals.append("punctuation_trimmed")

    var collapsed = collapse_whitespace(replaced)
    if collapsed != replaced:
        signals.append("whitespace_collapsed")

    return collapsed^


def contains_token(items: List[String], token: String) -> Bool:
    """Return True if `token` is equal to some element of `items`."""
    for item in items:
        if item == token:
            return True
    return False


def _is_stop_word(token: String) -> Bool:
    """Return True for the small fixed set of words dropped from query terms.

    In `analyze_query_text`, "near" is always consumed as a local-intent
    token before this check runs. "me" reaches this check unless it directly
    follows "near".
    """
    return (
        token == "a"
        or token == "an"
        or token == "and"
        or token == "for"
        or token == "from"
        or token == "in"
        or token == "me"
        or token == "near"
        or token == "of"
        or token == "on"
        or token == "the"
        or token == "to"
        or token == "with"
    )


@fieldwise_init
struct ExtractedFilters(Copyable, Movable):
    """Structured filters pulled out of query tokens by `analyze_query_text`.
    """

    # True when a local-intent token ("near", "nearby", "local", or the "me"
    # in "near me") was seen.
    var local_intent: Bool
    # "pickup", "delivery", or "unspecified".
    var fulfillment: String
    # "weekend" or "unspecified".
    var time_window: String


@fieldwise_init
struct QueryAnalysis(Copyable, Movable):
    """Result of `analyze_query_text`."""

    # The whitespace-collapsed input text the analysis started from.
    var original_text: String
    # Output of `normalize_free_text` on `original_text`.
    var normalized_text: String
    # `query_terms` joined with single spaces.
    var rewritten_text: String
    # Deduplicated tokens left after filter and stop-word extraction.
    var query_terms: List[String]
    # Tags describing which normalizations/detections fired.
    var normalization_signals: List[String]
    # Tags suggesting ranking preferences to downstream consumers.
    var ranking_hints: List[String]
    var extracted_filters: ExtractedFilters


@fieldwise_init
struct QueryRewriteRequest(Copyable, Movable):
    """Validated query_rewrite input: whitespace-collapsed, non-empty text."""

    var text: String


def _read_text_field(
    input: Value, field: String, capability_name: String
) raises -> String:
    """Read `input[field]` as a string and return it whitespace-collapsed.

    Shared by `extract_text_input` and `parse_query_rewrite_request` so both
    apply the same validation and error wording.

    Args:
        input: JSON object known to contain `field`.
        field: Name of the field to read ("text" or "query").
        capability_name: Prefix used in error messages.

    Returns:
        The field's string value with whitespace collapsed.

    Raises:
        Error: If the field is not a JSON string, or if it is empty after
            whitespace collapsing.
    """
    var field_value = input[field]
    if not field_value.is_string():
        raise Error(
            capability_name + " input field '" + field + "' must be a string"
        )
    var collapsed = collapse_whitespace(field_value.string_value())
    if collapsed == "":
        raise Error(capability_name + " input text must not be empty")
    return collapsed^


def extract_text_input(input: Value, capability_name: String) raises -> String:
    """Extract the query text from a capability's JSON input.

    "text" takes precedence over "query" when both are present; other keys
    are ignored (contrast `parse_query_rewrite_request`, which is strict).

    Args:
        input: The capability input; must be a JSON object.
        capability_name: Prefix used in error messages.

    Returns:
        The chosen field's value with whitespace collapsed (never empty).

    Raises:
        Error: If `input` is not an object, neither field is present, the
            field is not a string, or the text is empty.
    """
    _require_object(input, capability_name + " input")

    if has_key(input, "text"):
        return _read_text_field(input, "text", capability_name)
    if has_key(input, "query"):
        return _read_text_field(input, "query", capability_name)
    raise Error(capability_name + " input requires 'text' or 'query'")


def parse_query_rewrite_request(input: Value) raises -> QueryRewriteRequest:
    """Strictly validate query_rewrite input and build a request from it.

    The input must be an object with exactly one of "text" or "query" and
    no other fields.

    Raises:
        Error: On a non-object input, unexpected fields, both or neither of
            the two fields, a non-string value, or empty text.
    """
    _require_object(input, "query_rewrite input")
    _require_allowed_keys(input, "text", "query", "query_rewrite input")

    var has_text = has_key(input, "text")
    var has_query = has_key(input, "query")

    if has_text and has_query:
        raise Error(
            "query_rewrite input must provide exactly one of 'text' or 'query'"
        )
    if not has_text and not has_query:
        raise Error(
            "query_rewrite input requires exactly one of 'text' or 'query'"
        )

    var source_field = String("text") if has_text else String("query")
    var text = _read_text_field(input, source_field, "query_rewrite")
    return QueryRewriteRequest(text=text^)


def analyze_query_text(
    original_text: String, context: RequestContext
) -> QueryAnalysis:
    """Normalize a query and split it into search terms, filters, and hints.

    Filter tokens (local intent, fulfillment, weekend) and stop words are
    removed from the query terms; the remaining tokens are deduplicated in
    first-seen order. If nothing remains, the whole normalized text becomes
    the single term, unless the normalized text is itself empty.

    Args:
        original_text: Query text (already whitespace-collapsed by callers in
            this file).
        context: Request context; a truthy `scope` adds scope signals.

    Returns:
        The populated `QueryAnalysis`.
    """
    var original_copy = String(original_text)

    var normalization_signals = List[String]()
    var normalized_text = normalize_free_text(
        original_copy, normalization_signals
    )
    var normalized_tokens = normalized_text.split()

    var query_terms = List[String]()
    var ranking_hints = List[String]()
    var local_intent = False
    var fulfillment = "unspecified"
    var time_window = "unspecified"
    var removed_stop_words = False
    var extracted_filter_tokens = False
    # "me" only signals local intent as part of "near me"; elsewhere
    # ("show me ...") it is an ordinary stop word.
    var previous_was_near = False

    for raw_token in normalized_tokens:
        var token = String(raw_token)
        if token == "":
            continue

        var follows_near = previous_was_near
        previous_was_near = token == "near"

        if (
            token == "near"
            or token == "nearby"
            or token == "local"
            or (token == "me" and follows_near)
        ):
            local_intent = True
            extracted_filter_tokens = True
            continue

        # For fulfillment and time window the last matching token wins.
        if token == "pickup" or token == "curbside":
            fulfillment = "pickup"
            extracted_filter_tokens = True
            continue

        if token == "delivery" or token == "ship" or token == "shipping":
            fulfillment = "delivery"
            extracted_filter_tokens = True
            continue

        if token == "weekend" or token == "saturday" or token == "sunday":
            time_window = "weekend"
            extracted_filter_tokens = True
            continue

        if _is_stop_word(token):
            removed_stop_words = True
            continue

        if not contains_token(query_terms, token):
            query_terms.append(token)

    if local_intent:
        normalization_signals.append("local_intent_detected")
        ranking_hints.append("prefer_local_results")
    if fulfillment == "pickup":
        normalization_signals.append("pickup_filter_detected")
        ranking_hints.append("prefer_pickup")
    elif fulfillment == "delivery":
        normalization_signals.append("delivery_filter_detected")
        ranking_hints.append("prefer_delivery")
    if time_window == "weekend":
        normalization_signals.append("weekend_filter_detected")
        ranking_hints.append("prefer_weekend_availability")
    if removed_stop_words:
        normalization_signals.append("stopwords_removed")
    if extracted_filter_tokens:
        normalization_signals.append("filter_tokens_extracted")
    if context.scope:
        ranking_hints.append("respect_scope")
        normalization_signals.append("scope_present")

    # Punctuation-only input normalizes to ""; falling back would add an
    # empty query term, so only fall back when there is text to use.
    if len(query_terms) == 0 and normalized_text != "":
        query_terms.append(String(normalized_text))
        normalization_signals.append("fallback_to_normalized_query")

    return QueryAnalysis(
        original_text=original_copy,
        normalized_text=normalized_text,
        rewritten_text=join_strings(query_terms),
        query_terms=query_terms^,
        normalization_signals=normalization_signals^,
        ranking_hints=ranking_hints^,
        extracted_filters=ExtractedFilters(
            local_intent=local_intent,
            fulfillment=fulfillment,
            time_window=time_window,
        ),
    )


def analyze_query(
    input: Value, context: RequestContext, capability_name: String
) raises -> QueryAnalysis:
    """Extract the text from `input` and run `analyze_query_text` on it.

    Raises:
        Error: Propagated from `extract_text_input` on invalid input.
    """
    var original_text = extract_text_input(input, capability_name)
    return analyze_query_text(original_text, context)


def serialize_extracted_filters(filters: ExtractedFilters) raises -> Value:
    """Serialize `filters` to a JSON object.

    The object's keys are "local_intent", "fulfillment" and "time_window".
    """
    var value = loads("{}")
    value.set("local_intent", Value(filters.local_intent))
    value.set("fulfillment", Value(String(filters.fulfillment)))
    value.set("time_window", Value(String(filters.time_window)))
    return value^


def query_signal_tags(analysis: QueryAnalysis) -> List[String]:
    """Return the analysis's normalization signals followed by its ranking hints.
    """
    var signal_tags = copy_string_list(analysis.normalization_signals)
    for hint in analysis.ranking_hints:
        signal_tags.append(String(hint))
    return signal_tags^


def build_deterministic_meta(
    context: RequestContext,
    capability_name: String,
    signal_tags: List[String],
    extra_source_refs: List[ProvenanceSourceRef],
) -> CoreResponseMeta:
    """Build response metadata for a deterministic (heuristic) capability run.

    When `context.return_provenance` is set, the metadata carries an
    `ExecutionProvenance` whose source refs are, in order:
    - the capability's own input;
    - copies of `extra_source_refs`;
    - the request scope, if `context.scope` is truthy.

    Otherwise provenance is None, and the source refs are not built at all.

    Args:
        context: Request context controlling provenance and scope refs.
        capability_name: Used to label the "local_input" source ref.
        signal_tags: Copied into the provenance's signal tags.
        extra_source_refs: Additional refs copied into the provenance.

    Returns:
        The `CoreResponseMeta` (schema_version 1).
    """
    if context.return_provenance:
        var source_refs = List[ProvenanceSourceRef]()
        source_refs.append(
            ProvenanceSourceRef(
                source_kind="local_input",
                source_ref=capability_name + ":input",
            )
        )
        for source_ref in extra_source_refs:
            source_refs.append(
                ProvenanceSourceRef(
                    source_kind=String(source_ref.source_kind),
                    source_ref=String(source_ref.source_ref),
                )
            )
        if context.scope:
            source_refs.append(
                ProvenanceSourceRef(
                    source_kind="request_scope",
                    source_ref="request_context.scope",
                )
            )

        return CoreResponseMeta(
            execution_mode="deterministic",
            backend="heuristic",
            provider=None,
            route=None,
            model=None,
            latency_ms=None,
            schema_version=Optional[Int](1),
            prompt_version=None,
            fallback_kind=None,
            fallback_reason=None,
            provenance=ExecutionProvenance(
                kind="deterministic",
                signal_tags=copy_string_list(signal_tags),
                source_refs=source_refs^,
                fallback=None,
                evidence_set_id=None,
            ),
        )

    return CoreResponseMeta(
        execution_mode="deterministic",
        backend="heuristic",
        provider=None,
        route=None,
        model=None,
        latency_ms=None,
        schema_version=Optional[Int](1),
        prompt_version=None,
        fallback_kind=None,
        fallback_reason=None,
        provenance=None,
    )