hyf

Context-aware query service for Radroots
git clone https://radroots.dev/git/hyf.git
Log | Files | Refs | README | LICENSE

test_provider_adapter.mojo (12366B)


      1 from std.testing import TestSuite, assert_equal, assert_raises, assert_true
      2 
      3 from json import Value, loads
      4 
      5 from hyf_assist.contract import max_local_query_rewrite_route
      6 from hyf_core.request_context import default_request_context
      7 from hyf_provider.client import (
      8     get_max_local_health,
      9     max_local_chat_completions_url,
     10     post_max_local_chat_completion,
     11 )
     12 from hyf_provider.config import (
     13     MaxLocalProviderConfig,
     14     max_local_provider_config_from_runtime,
     15 )
     16 from hyf_provider.health import max_local_health_failure_from_reason
     17 from hyf_provider.max_local import max_local_query_rewrite_failure_from_reason
     18 from hyf_provider.result import parse_query_analysis_from_chat_completion
     19 from hyf_provider.schema import build_query_rewrite_request_body
     20 from hyf_runtime.config import (
     21     HyfAssistedRuntimeConfig,
     22     HyfExecutionRuntimeConfig,
     23     HyfLoadedRuntimeConfig,
     24     HyfMaxLocalProviderRuntimeConfig,
     25     HyfRuntimeConfig,
     26     HyfServiceRuntimeConfig,
     27     default_loaded_runtime_config,
     28 )
     29 from max_local_process_helper import (
     30     reserve_loopback_port,
     31     spawn_max_local_stub,
     32 )
     33 
     34 
     35 def _provider_runtime_config() -> HyfLoadedRuntimeConfig:
     36     return HyfLoadedRuntimeConfig(
     37         artifact_present=True,
     38         loaded=True,
     39         compiled_defaults_active=False,
     40         load_state="loaded",
     41         load_error="",
     42         effective=HyfRuntimeConfig(
     43             service=HyfServiceRuntimeConfig(transport="stdio"),
     44             runtime=HyfExecutionRuntimeConfig(
     45                 default_execution_mode="deterministic",
     46                 allow_assisted=True,
     47             ),
     48             assisted=HyfAssistedRuntimeConfig(
     49                 provider="max_local",
     50                 max_local=HyfMaxLocalProviderRuntimeConfig(
     51                     enabled=True,
     52                     base_url="http://127.0.0.1:8000/v1/",
     53                     health_url="http://127.0.0.1:8000/health",
     54                     model="max-local-query-rewrite",
     55                     request_timeout_ms=15000,
     56                 ),
     57             ),
     58         ),
     59     )
     60 
     61 
     62 def _provider_config() -> MaxLocalProviderConfig:
     63     return MaxLocalProviderConfig(
     64         base_url="http://127.0.0.1:8000/v1/",
     65         health_url="http://127.0.0.1:8000/health",
     66         model="max-local-query-rewrite",
     67         request_timeout_ms=15000,
     68     )
     69 
     70 
     71 def _provider_config_for_port(port: Int) -> MaxLocalProviderConfig:
     72     return MaxLocalProviderConfig(
     73         base_url="http://127.0.0.1:" + String(port) + "/v1/",
     74         health_url="http://127.0.0.1:" + String(port) + "/health",
     75         model="max-local-query-rewrite",
     76         request_timeout_ms=15000,
     77     )
     78 
     79 
     80 def _invalid_base_url_provider_config() -> MaxLocalProviderConfig:
     81     return MaxLocalProviderConfig(
     82         base_url="ftp://127.0.0.1:8000/v1/",
     83         health_url="http://127.0.0.1:8000/health",
     84         model="max-local-query-rewrite",
     85         request_timeout_ms=15000,
     86     )
     87 
     88 
     89 def _invalid_health_url_provider_config() -> MaxLocalProviderConfig:
     90     return MaxLocalProviderConfig(
     91         base_url="http://127.0.0.1:8000/v1/",
     92         health_url="ftp://127.0.0.1:8000/health",
     93         model="max-local-query-rewrite",
     94         request_timeout_ms=15000,
     95     )
     96 
     97 
     98 def _analysis_json_text() -> String:
     99     return (
    100         '{"original_text":"eggs near me",'
    101         '"normalized_text":"eggs near me",'
    102         '"rewritten_text":"eggs",'
    103         '"query_terms":["eggs"],'
    104         '"normalization_signals":["local_intent_detected"],'
    105         '"ranking_hints":["prefer_local_results"],'
    106         '"extracted_filters":{'
    107         '"local_intent":true,'
    108         '"fulfillment":"unspecified",'
    109         '"time_window":"unspecified"'
    110         "}}"
    111     )
    112 
    113 
    114 def _chat_completion_response_with_content(content: String) raises -> Value:
    115     var response = loads("{}")
    116     var choices = loads("[]")
    117     var choice = loads("{}")
    118     var message = loads("{}")
    119     message.set("content", Value(content))
    120     choice.set("message", message)
    121     choices.append(choice)
    122     response.set("choices", choices)
    123     return response^
    124 
    125 
    126 def _chat_completion_response() raises -> Value:
    127     return _chat_completion_response_with_content(_analysis_json_text())
    128 
    129 
    130 def _assert_query_rewrite_failure(
    131     reason: String, expected_kind: String, expected_reason: String
    132 ) raises:
    133     var failure = max_local_query_rewrite_failure_from_reason(reason)
    134     assert_equal(failure.kind, expected_kind)
    135     assert_equal(failure.reason, expected_reason)
    136 
    137 
    138 def _assert_health_failure(
    139     reason: String, expected_kind: String, expected_reason: String
    140 ) raises:
    141     var failure = max_local_health_failure_from_reason(reason)
    142     assert_equal(failure.kind, expected_kind)
    143     assert_equal(failure.reason, expected_reason)
    144 
    145 
    146 def _assert_chat_completion_parse_failure(
    147     response: Value, expected_error: String
    148 ) raises:
    149     try:
    150         _ = parse_query_analysis_from_chat_completion(response)
    151     except e:
    152         assert_equal(String(e), expected_error)
    153         return
    154     raise Error("expected chat completion parse failure")
    155 
    156 
    157 def test_provider_config_maps_runtime_config() raises:
    158     var config = max_local_provider_config_from_runtime(
    159         _provider_runtime_config()
    160     )
    161 
    162     assert_equal(config.base_url, "http://127.0.0.1:8000/v1/")
    163     assert_equal(config.health_url, "http://127.0.0.1:8000/health")
    164     assert_equal(config.model, "max-local-query-rewrite")
    165     assert_equal(config.request_timeout_ms, 15000)
    166 
    167 
    168 def test_max_local_route_is_derived_from_assisted_contract() raises:
    169     assert_equal(
    170         max_local_query_rewrite_route(),
    171         "provider_runtime.query_rewrite.max_local",
    172     )
    173 
    174 
    175 def test_max_local_provider_failure_mapping_preserves_reason_tokens() raises:
    176     _assert_query_rewrite_failure("timeout", "transport", "timeout")
    177     _assert_query_rewrite_failure(
    178         "connection_failed", "transport", "connection_failed"
    179     )
    180     _assert_query_rewrite_failure("invalid_url", "transport", "invalid_url")
    181     _assert_query_rewrite_failure(
    182         "provider_non_2xx", "http_status", "provider_non_2xx"
    183     )
    184     _assert_query_rewrite_failure(
    185         "provider_error_payload",
    186         "provider_payload",
    187         "provider_error_payload",
    188     )
    189     _assert_query_rewrite_failure(
    190         "provider_invalid_json",
    191         "provider_payload",
    192         "provider_invalid_json",
    193     )
    194     _assert_query_rewrite_failure(
    195         "provider_schema_invalid",
    196         "provider_payload",
    197         "provider_schema_invalid",
    198     )
    199     _assert_query_rewrite_failure(
    200         "provider_empty_choices",
    201         "provider_payload",
    202         "provider_empty_choices",
    203     )
    204     _assert_query_rewrite_failure(
    205         "provider_missing_content",
    206         "provider_payload",
    207         "provider_missing_content",
    208     )
    209     _assert_query_rewrite_failure(
    210         "unknown_transport", "provider", "provider_error"
    211     )
    212     _assert_query_rewrite_failure(
    213         "unknown_provider", "provider", "provider_error"
    214     )
    215 
    216 
    217 def test_max_local_health_failure_mapping_preserves_reason_tokens() raises:
    218     _assert_health_failure("timeout", "transport", "timeout")
    219     _assert_health_failure("invalid_url", "transport", "invalid_url")
    220     _assert_health_failure(
    221         "connection_failed", "transport", "connection_failed"
    222     )
    223     _assert_health_failure("non_2xx", "http_status", "non_2xx")
    224     _assert_health_failure(
    225         "unknown_transport", "transport", "connection_failed"
    226     )
    227 
    228 
    229 def test_provider_config_rejects_unconfigured_runtime() raises:
    230     with assert_raises():
    231         _ = max_local_provider_config_from_runtime(
    232             default_loaded_runtime_config()
    233         )
    234 
    235 
    236 def test_max_local_chat_completions_url_trims_base_url() raises:
    237     assert_equal(
    238         max_local_chat_completions_url(_provider_config()),
    239         "http://127.0.0.1:8000/v1/chat/completions",
    240     )
    241 
    242 
    243 def test_max_local_transport_boundary_rejects_invalid_chat_url() raises:
    244     var outcome = post_max_local_chat_completion(
    245         _invalid_base_url_provider_config(), loads("{}")
    246     )
    247 
    248     assert_true(outcome.failure)
    249     assert_true(not outcome.response)
    250     assert_equal(outcome.failure.value().kind, "transport")
    251     assert_equal(outcome.failure.value().reason, "invalid_url")
    252 
    253 
    254 def test_max_local_transport_boundary_rejects_invalid_health_url() raises:
    255     var outcome = get_max_local_health(_invalid_health_url_provider_config())
    256 
    257     assert_true(outcome.failure)
    258     assert_true(not outcome.response)
    259     assert_equal(outcome.failure.value().kind, "transport")
    260     assert_equal(outcome.failure.value().reason, "invalid_url")
    261 
    262 
    263 def test_max_local_transport_boundary_reports_unknown_chat_transport() raises:
    264     var provider_port = reserve_loopback_port()
    265     var provider_stub = spawn_max_local_stub(
    266         provider_port, "query_rewrite_malformed_http", 1
    267     )
    268     var outcome = post_max_local_chat_completion(
    269         _provider_config_for_port(provider_port), loads("{}")
    270     )
    271 
    272     assert_true(outcome.failure)
    273     assert_true(not outcome.response)
    274     assert_equal(outcome.failure.value().kind, "transport")
    275     assert_equal(outcome.failure.value().reason, "unknown_transport")
    276 
    277     provider_stub.wait()
    278 
    279 
    280 def test_max_local_transport_boundary_reports_unknown_health_transport() raises:
    281     var provider_port = reserve_loopback_port()
    282     var provider_stub = spawn_max_local_stub(
    283         provider_port, "health_malformed_http", 1
    284     )
    285     var outcome = get_max_local_health(_provider_config_for_port(provider_port))
    286 
    287     assert_true(outcome.failure)
    288     assert_true(not outcome.response)
    289     assert_equal(outcome.failure.value().kind, "transport")
    290     assert_equal(outcome.failure.value().reason, "unknown_transport")
    291 
    292     provider_stub.wait()
    293 
    294 
    295 def test_query_rewrite_request_body_sets_schema_contract() raises:
    296     var context = default_request_context()
    297     context.return_provenance = True
    298     var body = build_query_rewrite_request_body(
    299         _provider_config(), "eggs near me", context
    300     )
    301 
    302     assert_equal(body["model"].string_value(), "max-local-query-rewrite")
    303     assert_equal(body["messages"][0]["role"].string_value(), "system")
    304     assert_equal(body["messages"][1]["role"].string_value(), "user")
    305     assert_true(
    306         body["messages"][1]["content"].string_value().find("eggs near me")
    307         >= 0
    308     )
    309     assert_equal(body["response_format"]["type"].string_value(), "json_schema")
    310     assert_equal(
    311         body["response_format"]["json_schema"]["name"].string_value(),
    312         "query_rewrite",
    313     )
    314     assert_equal(
    315         body["response_format"]["json_schema"]["strict"].bool_value(), True
    316     )
    317     assert_equal(
    318         body["response_format"]["json_schema"]["schema"]["type"]
    319         .string_value(),
    320         "object",
    321     )
    322 
    323 
    324 def test_chat_completion_response_parses_query_analysis() raises:
    325     var analysis = parse_query_analysis_from_chat_completion(
    326         _chat_completion_response()
    327     )
    328 
    329     assert_equal(analysis.original_text, "eggs near me")
    330     assert_equal(analysis.normalized_text, "eggs near me")
    331     assert_equal(analysis.rewritten_text, "eggs")
    332     assert_equal(len(analysis.query_terms), 1)
    333     assert_equal(analysis.query_terms[0], "eggs")
    334     assert_equal(analysis.extracted_filters.local_intent, True)
    335 
    336 
    337 def test_chat_completion_response_rejects_invalid_json_content() raises:
    338     _assert_chat_completion_parse_failure(
    339         _chat_completion_response_with_content("not json"),
    340         "provider_invalid_json",
    341     )
    342 
    343 
    344 def test_chat_completion_response_rejects_schema_invalid_content() raises:
    345     _assert_chat_completion_parse_failure(
    346         _chat_completion_response_with_content('{"original_text":"eggs"}'),
    347         "provider_schema_invalid",
    348     )
    349 
    350 
    351 def test_chat_completion_response_rejects_empty_choices() raises:
    352     with assert_raises():
    353         _ = parse_query_analysis_from_chat_completion(loads('{"choices":[]}'))
    354 
    355 
    356 def test_chat_completion_response_rejects_top_level_scalar() raises:
    357     with assert_raises():
    358         _ = parse_query_analysis_from_chat_completion(loads('"not object"'))
    359 
    360 
    361 def test_chat_completion_response_rejects_top_level_array() raises:
    362     with assert_raises():
    363         _ = parse_query_analysis_from_chat_completion(loads("[]"))
    364 
    365 
    366 def test_chat_completion_response_rejects_top_level_null() raises:
    367     with assert_raises():
    368         _ = parse_query_analysis_from_chat_completion(loads("null"))
    369 
    370 
def main() raises:
    """Test entry point: discover this module's test functions and run them.

    `TestSuite.discover_tests` is parameterized on `__functions_in_module()`,
    so the discovered suite comes from the functions defined in this file.
    """
    TestSuite.discover_tests[__functions_in_module()]().run()