commit 79a7fbf061112916743c2754164f5fae4ec89ea8
parent 08b5a2a260747796d08fa4277a761b688bf0da98
Author: triesap <tyson@radroots.org>
Date: Wed, 8 Apr 2026 19:51:51 +0000
capabilities: type query rewrite input contract
Diffstat:
3 files changed, 114 insertions(+), 8 deletions(-)
diff --git a/src/hyf_core/capabilities/query_analysis.mojo b/src/hyf_core/capabilities/query_analysis.mojo
@@ -10,6 +10,19 @@ from hyf_core.provenance import (
from hyf_core.request_context import RequestContext
+def _require_object(value: Value, context: String) raises:  # Guard: raises Error unless `value` is a JSON object.
+ if not value.is_object():
+ raise Error(context + " must be a JSON object")  # `context` is used as the prefix of the error message
+
+
+def _require_allowed_keys(  # Guard: raises Error if `value` has any key other than `key_a` or `key_b`.
+ value: Value, key_a: String, key_b: String, context: String
+) raises:
+ for key in value.object_keys():
+ if key != key_a and key != key_b:
+ raise Error(context + " contains unexpected field '" + key + "'")  # raised for the first disallowed key encountered; message names that key
+
+
def has_key(value: Value, key: String) -> Bool:
for candidate in value.object_keys():
if candidate == key:
@@ -130,6 +143,11 @@ struct QueryAnalysis(Copyable, Movable):
var extracted_filters: ExtractedFilters
+@fieldwise_init
+struct QueryRewriteRequest(Copyable, Movable):  # Typed form of the query_rewrite input; built by parse_query_rewrite_request.
+ var text: String  # parse_query_rewrite_request stores the whitespace-collapsed, non-empty query text here
+
+
def extract_text_input(input: Value, capability_name: String) raises -> String:
if not input.is_object():
raise Error(capability_name + " input must be a JSON object")
@@ -160,13 +178,45 @@ def extract_text_input(input: Value, capability_name: String) raises -> String:
)
-def analyze_query(
- input: Value, context: RequestContext, capability_name: String
-) raises -> QueryAnalysis:
- var original_text = extract_text_input(input, capability_name)
+def parse_query_rewrite_request(input: Value) raises -> QueryRewriteRequest:  # Validates the raw query_rewrite input and returns it as a typed request; raises Error on any violation below.
+ _require_object(input, "query_rewrite input")
+ _require_allowed_keys(input, "text", "query", "query_rewrite input")  # only the `text` and `query` keys are accepted
+
+ var has_text = has_key(input, "text")
+ var has_query = has_key(input, "query")
+
+ if has_text and has_query:  # both fields supplied: rejected
+ raise Error(
+ "query_rewrite input must provide exactly one of 'text' or 'query'"
+ )
+ if not has_text and not has_query:  # neither field supplied: rejected
+ raise Error(
+ "query_rewrite input requires exactly one of 'text' or 'query'"
+ )
+
+ var source_field = "text" if has_text else "query"  # exactly one field is present at this point; this names it
+ var text_value = input[source_field]
+ if not text_value.is_string():
+ raise Error(
+ "query_rewrite input field '" + source_field + "' must be a string"
+ )
+
+ var collapsed = collapse_whitespace(text_value.string_value())
+ if collapsed == "":  # rejects input whose collapse_whitespace result is the empty string
+ raise Error("query_rewrite input text must not be empty")
+
+ return QueryRewriteRequest(text=collapsed)  # the request carries the collapsed text, not the raw field value
+
+
+def analyze_query_text(
+ original_text: String, context: RequestContext
+) -> QueryAnalysis:
+ var normalized_input = String(original_text)
var normalization_signals = List[String]()
- var normalized_text = normalize_free_text(original_text, normalization_signals)
+ var normalized_text = normalize_free_text(
+ normalized_input, normalization_signals
+ )
var normalized_tokens = normalized_text.split()
var query_terms = List[String]()
@@ -239,7 +289,7 @@ def analyze_query(
normalization_signals.append("fallback_to_normalized_query")
return QueryAnalysis(
- original_text=original_text,
+ original_text=normalized_input,
normalized_text=normalized_text,
rewritten_text=join_strings(query_terms),
query_terms=query_terms^,
@@ -253,6 +303,13 @@ def analyze_query(
)
+def analyze_query(  # Value-based entry point: extracts the text from `input`, then delegates to analyze_query_text.
+ input: Value, context: RequestContext, capability_name: String
+) raises -> QueryAnalysis:
+ var original_text = extract_text_input(input, capability_name)  # extract_text_input is declared `raises`, so its errors propagate to the caller
+ return analyze_query_text(original_text, context)  # the analysis itself is done by the text-based function
+
+
def serialize_extracted_filters(filters: ExtractedFilters) raises -> Value:
var value = loads("{}")
value.set("local_intent", Value(filters.local_intent))
diff --git a/src/hyf_core/capabilities/query_rewrite.mojo b/src/hyf_core/capabilities/query_rewrite.mojo
@@ -4,8 +4,10 @@ from mojson import Value, loads
from hyf_core.capabilities.query_analysis import (
QueryAnalysis,
- analyze_query,
+ QueryRewriteRequest,
+ analyze_query_text,
build_deterministic_meta,
+ parse_query_rewrite_request,
query_signal_tags,
serialize_extracted_filters,
string_array_value,
@@ -51,7 +53,8 @@ def execute_query_rewrite(
)
try:
- var analysis = analyze_query(input, context, "query_rewrite")
+ var request: QueryRewriteRequest = parse_query_rewrite_request(input)
+ var analysis = analyze_query_text(request.text, context)
var source_refs = List[ProvenanceSourceRef]()
return successful_capability(
diff --git a/tests/test_hyf.mojo b/tests/test_hyf.mojo
@@ -230,6 +230,52 @@ def test_query_rewrite_returns_deterministic_output() raises:
assert_equal(result["meta"]["backend"].string_value(), "heuristic")
+def test_query_rewrite_accepts_query_alias_with_same_behavior() raises:  # Dispatches query_rewrite with the `query` field instead of `text` and expects a successful rewrite.
+ var result = _dispatch(
+ '{"version":1,"request_id":"rewrite-query-1","capability":"query_rewrite","input":{"query":"eggs near me with weekend pickup"}}'
+ )
+
+ assert_equal(Int(result["version"].int_value()), 1)
+ assert_equal(result["ok"].bool_value(), True)
+ assert_equal(  # the rewritten text is expected to be reduced to the single term eggs
+ result["output"]["rewritten_text"].string_value(),
+ "eggs",
+ )
+ assert_equal(  # the fulfillment filter is expected to be extracted as pickup
+ result["output"]["extracted_filters"]["fulfillment"].string_value(),
+ "pickup",
+ )
+
+
+def test_query_rewrite_rejects_unknown_input_field() raises:  # Dispatches query_rewrite with an extra `tone` field and expects an invalid_request error.
+ var result = _dispatch(
+ '{"version":1,"request_id":"rewrite-bad-field-1","capability":"query_rewrite","input":{"text":"eggs near me","tone":"brief"}}'
+ )
+
+ assert_equal(Int(result["version"].int_value()), 1)
+ assert_equal(result["ok"].bool_value(), False)
+ assert_equal(result["request_id"].string_value(), "rewrite-bad-field-1")  # the error response is expected to echo the request id
+ assert_equal(result["error"]["code"].string_value(), "invalid_request")
+ assert_true(  # substring match on the message raised by _require_allowed_keys
+ result["error"]["message"].string_value().find("unexpected field")
+ >= 0
+ )
+
+
+def test_query_rewrite_rejects_text_and_query_together() raises:  # Dispatches query_rewrite with both `text` and `query` set and expects an invalid_request error.
+ var result = _dispatch(
+ '{"version":1,"request_id":"rewrite-bad-dual-1","capability":"query_rewrite","input":{"text":"eggs near me","query":"eggs"}}'
+ )
+
+ assert_equal(Int(result["version"].int_value()), 1)
+ assert_equal(result["ok"].bool_value(), False)
+ assert_equal(result["request_id"].string_value(), "rewrite-bad-dual-1")  # the error response is expected to echo the request id
+ assert_equal(result["error"]["code"].string_value(), "invalid_request")
+ assert_true(  # substring match on the message raised by parse_query_rewrite_request
+ result["error"]["message"].string_value().find("exactly one") >= 0
+ )
+
+
def test_semantic_rank_returns_ranked_ids_and_reasons() raises:
var result = _dispatch(
'{"version":1,"request_id":"rank-1","capability":"semantic_rank","input":{"query":"eggs near me with weekend pickup","candidates":[{"id":"lst_7ak2","title":"Pasture eggs","farm":"La Huerta del Sur","delivery":"pickup","distance_km":3.2,"freshness_minutes":2},{"id":"lst_8k1p","title":"Free range eggs","farm":"Santa Elena","delivery":"delivery","distance_km":8.7,"freshness_minutes":18}]}}'