test_provider_adapter.mojo (12366B)
"""Tests for the max_local provider adapter.

Covers runtime-config mapping, failure-reason mapping, transport-boundary
URL validation, query-rewrite request-body construction, and parsing of
chat-completion responses into a query analysis.
"""

from std.testing import TestSuite, assert_equal, assert_raises, assert_true

from json import Value, loads

from hyf_assist.contract import max_local_query_rewrite_route
from hyf_core.request_context import default_request_context
from hyf_provider.client import (
    get_max_local_health,
    max_local_chat_completions_url,
    post_max_local_chat_completion,
)
from hyf_provider.config import (
    MaxLocalProviderConfig,
    max_local_provider_config_from_runtime,
)
from hyf_provider.health import max_local_health_failure_from_reason
from hyf_provider.max_local import max_local_query_rewrite_failure_from_reason
from hyf_provider.result import parse_query_analysis_from_chat_completion
from hyf_provider.schema import build_query_rewrite_request_body
from hyf_runtime.config import (
    HyfAssistedRuntimeConfig,
    HyfExecutionRuntimeConfig,
    HyfLoadedRuntimeConfig,
    HyfMaxLocalProviderRuntimeConfig,
    HyfRuntimeConfig,
    HyfServiceRuntimeConfig,
    default_loaded_runtime_config,
)
from max_local_process_helper import (
    reserve_loopback_port,
    spawn_max_local_stub,
)


def _provider_runtime_config() -> HyfLoadedRuntimeConfig:
    """Builds a loaded runtime config with the max_local provider enabled.

    The config reports a present, loaded artifact, allows assisted execution,
    and points the max_local provider at 127.0.0.1:8000 with a 15000 ms
    request timeout.

    Returns:
        A `HyfLoadedRuntimeConfig` suitable for
        `max_local_provider_config_from_runtime`.
    """
    return HyfLoadedRuntimeConfig(
        artifact_present=True,
        loaded=True,
        compiled_defaults_active=False,
        load_state="loaded",
        load_error="",
        effective=HyfRuntimeConfig(
            service=HyfServiceRuntimeConfig(transport="stdio"),
            runtime=HyfExecutionRuntimeConfig(
                default_execution_mode="deterministic",
                allow_assisted=True,
            ),
            assisted=HyfAssistedRuntimeConfig(
                provider="max_local",
                max_local=HyfMaxLocalProviderRuntimeConfig(
                    enabled=True,
                    base_url="http://127.0.0.1:8000/v1/",
                    health_url="http://127.0.0.1:8000/health",
                    model="max-local-query-rewrite",
                    request_timeout_ms=15000,
                ),
            ),
        ),
    )


def _provider_config_with_urls(
    base_url: String, health_url: String
) -> MaxLocalProviderConfig:
    """Builds a provider config with the shared test model and timeout.

    Every provider-config fixture in this module differs only in its URLs,
    so the model name and request timeout are spelled out once here.

    Args:
        base_url: The provider API base URL (used for chat completions).
        health_url: The provider health-check URL.

    Returns:
        A `MaxLocalProviderConfig` for model "max-local-query-rewrite" with a
        15000 ms request timeout.
    """
    return MaxLocalProviderConfig(
        base_url=base_url,
        health_url=health_url,
        model="max-local-query-rewrite",
        request_timeout_ms=15000,
    )


def _provider_config() -> MaxLocalProviderConfig:
    """Returns the default test provider config targeting 127.0.0.1:8000."""
    return _provider_config_for_port(8000)


def _provider_config_for_port(port: Int) -> MaxLocalProviderConfig:
    """Returns a provider config whose URLs target 127.0.0.1 on `port`.

    Args:
        port: The loopback port serving both `/v1/` and `/health`.

    Returns:
        A config with base URL `http://127.0.0.1:<port>/v1/` and health URL
        `http://127.0.0.1:<port>/health`.
    """
    return _provider_config_with_urls(
        "http://127.0.0.1:" + String(port) + "/v1/",
        "http://127.0.0.1:" + String(port) + "/health",
    )


def _invalid_base_url_provider_config() -> MaxLocalProviderConfig:
    """Returns a config whose base URL uses the unsupported `ftp` scheme."""
    return _provider_config_with_urls(
        "ftp://127.0.0.1:8000/v1/",
        "http://127.0.0.1:8000/health",
    )


def _invalid_health_url_provider_config() -> MaxLocalProviderConfig:
    """Returns a config whose health URL uses the unsupported `ftp` scheme."""
    return _provider_config_with_urls(
        "http://127.0.0.1:8000/v1/",
        "ftp://127.0.0.1:8000/health",
    )


def _analysis_json_text() -> String:
    """Returns a complete query-analysis JSON document for "eggs near me".

    The document rewrites the query to "eggs" and reports local intent.
    """
    return (
        '{"original_text":"eggs near me",'
        '"normalized_text":"eggs near me",'
        '"rewritten_text":"eggs",'
        '"query_terms":["eggs"],'
        '"normalization_signals":["local_intent_detected"],'
        '"ranking_hints":["prefer_local_results"],'
        '"extracted_filters":{'
        '"local_intent":true,'
        '"fulfillment":"unspecified",'
        '"time_window":"unspecified"'
        "}}"
    )


def _chat_completion_response_with_content(content: String) raises -> Value:
    """Wraps `content` in a minimal chat-completion response.

    Produces `{"choices": [{"message": {"content": <content>}}]}`.

    Args:
        content: The message content string of the single choice.

    Returns:
        The response as a JSON `Value`.

    Raises:
        Any error raised while building the JSON values.
    """
    var response = loads("{}")
    var choices = loads("[]")
    var choice = loads("{}")
    var message = loads("{}")
    message.set("content", Value(content))
    choice.set("message", message)
    choices.append(choice)
    response.set("choices", choices)
    return response^


def _chat_completion_response() raises -> Value:
    """Returns a chat-completion response carrying `_analysis_json_text()`."""
    return _chat_completion_response_with_content(_analysis_json_text())


def _assert_query_rewrite_failure(
    reason: String, expected_kind: String, expected_reason: String
) raises:
    """Asserts how a raw query-rewrite failure reason is classified.

    Args:
        reason: The raw reason token passed to the mapper.
        expected_kind: The expected `failure.kind`.
        expected_reason: The expected `failure.reason`.

    Raises:
        An assertion error if either field differs.
    """
    var failure = max_local_query_rewrite_failure_from_reason(reason)
    assert_equal(failure.kind, expected_kind)
    assert_equal(failure.reason, expected_reason)


def _assert_health_failure(
    reason: String, expected_kind: String, expected_reason: String
) raises:
    """Asserts how a raw health-check failure reason is classified.

    Args:
        reason: The raw reason token passed to the mapper.
        expected_kind: The expected `failure.kind`.
        expected_reason: The expected `failure.reason`.

    Raises:
        An assertion error if either field differs.
    """
    var failure = max_local_health_failure_from_reason(reason)
    assert_equal(failure.kind, expected_kind)
    assert_equal(failure.reason, expected_reason)


def _assert_chat_completion_parse_failure(
    response: Value, expected_error: String
) raises:
    """Asserts that parsing `response` fails with exactly `expected_error`.

    Args:
        response: The chat-completion response to parse.
        expected_error: The exact expected error message.

    Raises:
        An error if parsing succeeds, or an assertion error if the raised
        message differs from `expected_error`.
    """
    try:
        _ = parse_query_analysis_from_chat_completion(response)
    except e:
        assert_equal(String(e), expected_error)
        return
    # Reaching here means the parser accepted input it should have rejected.
    raise Error("expected chat completion parse failure")


def test_provider_config_maps_runtime_config() raises:
    """The runtime max_local settings are copied into the provider config."""
    var config = max_local_provider_config_from_runtime(
        _provider_runtime_config()
    )

    assert_equal(config.base_url, "http://127.0.0.1:8000/v1/")
    assert_equal(config.health_url, "http://127.0.0.1:8000/health")
    assert_equal(config.model, "max-local-query-rewrite")
    assert_equal(config.request_timeout_ms, 15000)


def test_max_local_route_is_derived_from_assisted_contract() raises:
    """The assisted contract exposes the expected max_local route name."""
    assert_equal(
        max_local_query_rewrite_route(),
        "provider_runtime.query_rewrite.max_local",
    )


def test_max_local_provider_failure_mapping_preserves_reason_tokens() raises:
    """Known query-rewrite reasons keep their token; unknown ones collapse.

    Transport reasons map to kind "transport", non-2xx to "http_status",
    payload problems to "provider_payload"; unrecognized reasons map to
    kind "provider" with reason "provider_error".
    """
    _assert_query_rewrite_failure("timeout", "transport", "timeout")
    _assert_query_rewrite_failure(
        "connection_failed", "transport", "connection_failed"
    )
    _assert_query_rewrite_failure("invalid_url", "transport", "invalid_url")
    _assert_query_rewrite_failure(
        "provider_non_2xx", "http_status", "provider_non_2xx"
    )
    _assert_query_rewrite_failure(
        "provider_error_payload",
        "provider_payload",
        "provider_error_payload",
    )
    _assert_query_rewrite_failure(
        "provider_invalid_json",
        "provider_payload",
        "provider_invalid_json",
    )
    _assert_query_rewrite_failure(
        "provider_schema_invalid",
        "provider_payload",
        "provider_schema_invalid",
    )
    _assert_query_rewrite_failure(
        "provider_empty_choices",
        "provider_payload",
        "provider_empty_choices",
    )
    _assert_query_rewrite_failure(
        "provider_missing_content",
        "provider_payload",
        "provider_missing_content",
    )
    _assert_query_rewrite_failure(
        "unknown_transport", "provider", "provider_error"
    )
    _assert_query_rewrite_failure(
        "unknown_provider", "provider", "provider_error"
    )


def test_max_local_health_failure_mapping_preserves_reason_tokens() raises:
    """Known health reasons keep their token; unknown ones collapse.

    Unrecognized health reasons map to kind "transport" with reason
    "connection_failed".
    """
    _assert_health_failure("timeout", "transport", "timeout")
    _assert_health_failure("invalid_url", "transport", "invalid_url")
    _assert_health_failure(
        "connection_failed", "transport", "connection_failed"
    )
    _assert_health_failure("non_2xx", "http_status", "non_2xx")
    _assert_health_failure(
        "unknown_transport", "transport", "connection_failed"
    )


def test_provider_config_rejects_unconfigured_runtime() raises:
    """The default runtime config cannot produce a provider config."""
    with assert_raises():
        _ = max_local_provider_config_from_runtime(
            default_loaded_runtime_config()
        )


def test_max_local_chat_completions_url_trims_base_url() raises:
    """The trailing slash on the base URL is not doubled in the chat URL."""
    assert_equal(
        max_local_chat_completions_url(_provider_config()),
        "http://127.0.0.1:8000/v1/chat/completions",
    )


def test_max_local_transport_boundary_rejects_invalid_chat_url() raises:
    """A non-http base URL yields a transport/invalid_url failure."""
    var outcome = post_max_local_chat_completion(
        _invalid_base_url_provider_config(), loads("{}")
    )

    assert_true(outcome.failure)
    assert_true(not outcome.response)
    assert_equal(outcome.failure.value().kind, "transport")
    assert_equal(outcome.failure.value().reason, "invalid_url")


def test_max_local_transport_boundary_rejects_invalid_health_url() raises:
    """A non-http health URL yields a transport/invalid_url failure."""
    var outcome = get_max_local_health(_invalid_health_url_provider_config())

    assert_true(outcome.failure)
    assert_true(not outcome.response)
    assert_equal(outcome.failure.value().kind, "transport")
    assert_equal(outcome.failure.value().reason, "invalid_url")


def test_max_local_transport_boundary_reports_unknown_chat_transport() raises:
    """A malformed HTTP reply to a chat request is reported as unknown."""
    var provider_port = reserve_loopback_port()
    var provider_stub = spawn_max_local_stub(
        provider_port, "query_rewrite_malformed_http", 1
    )
    var outcome = post_max_local_chat_completion(
        _provider_config_for_port(provider_port), loads("{}")
    )

    assert_true(outcome.failure)
    assert_true(not outcome.response)
    assert_equal(outcome.failure.value().kind, "transport")
    assert_equal(outcome.failure.value().reason, "unknown_transport")

    # NOTE(review): if an assertion above raises, this wait() is skipped and
    # the spawned stub is not reaped. Presumably the stub exits after serving
    # one request (third argument `1`) — TODO confirm before moving wait()
    # ahead of the assertions, since it may block if no request arrived.
    provider_stub.wait()


def test_max_local_transport_boundary_reports_unknown_health_transport() raises:
    """A malformed HTTP reply to a health probe is reported as unknown."""
    var provider_port = reserve_loopback_port()
    var provider_stub = spawn_max_local_stub(
        provider_port, "health_malformed_http", 1
    )
    var outcome = get_max_local_health(_provider_config_for_port(provider_port))

    assert_true(outcome.failure)
    assert_true(not outcome.response)
    assert_equal(outcome.failure.value().kind, "transport")
    assert_equal(outcome.failure.value().reason, "unknown_transport")

    # NOTE(review): as above, a failed assertion skips reaping the stub;
    # confirm wait() cannot block before moving it earlier.
    provider_stub.wait()


def test_query_rewrite_request_body_sets_schema_contract() raises:
    """The request body carries model, messages and a strict JSON schema."""
    var context = default_request_context()
    context.return_provenance = True
    var body = build_query_rewrite_request_body(
        _provider_config(), "eggs near me", context
    )

    assert_equal(body["model"].string_value(), "max-local-query-rewrite")
    assert_equal(body["messages"][0]["role"].string_value(), "system")
    assert_equal(body["messages"][1]["role"].string_value(), "user")
    # The user message must embed the original query text somewhere.
    assert_true(
        body["messages"][1]["content"].string_value().find("eggs near me")
        >= 0
    )
    assert_equal(body["response_format"]["type"].string_value(), "json_schema")
    assert_equal(
        body["response_format"]["json_schema"]["name"].string_value(),
        "query_rewrite",
    )
    assert_equal(
        body["response_format"]["json_schema"]["strict"].bool_value(), True
    )
    assert_equal(
        body["response_format"]["json_schema"]["schema"]["type"]
        .string_value(),
        "object",
    )


def test_chat_completion_response_parses_query_analysis() raises:
    """A well-formed response content is parsed into a query analysis."""
    var analysis = parse_query_analysis_from_chat_completion(
        _chat_completion_response()
    )

    assert_equal(analysis.original_text, "eggs near me")
    assert_equal(analysis.normalized_text, "eggs near me")
    assert_equal(analysis.rewritten_text, "eggs")
    assert_equal(len(analysis.query_terms), 1)
    assert_equal(analysis.query_terms[0], "eggs")
    assert_equal(analysis.extracted_filters.local_intent, True)


def test_chat_completion_response_rejects_invalid_json_content() raises:
    """Non-JSON message content fails with "provider_invalid_json"."""
    _assert_chat_completion_parse_failure(
        _chat_completion_response_with_content("not json"),
        "provider_invalid_json",
    )


def test_chat_completion_response_rejects_schema_invalid_content() raises:
    """JSON content missing required fields fails schema validation."""
    _assert_chat_completion_parse_failure(
        _chat_completion_response_with_content('{"original_text":"eggs"}'),
        "provider_schema_invalid",
    )


def test_chat_completion_response_rejects_empty_choices() raises:
    """A response with an empty `choices` array is rejected."""
    with assert_raises():
        _ = parse_query_analysis_from_chat_completion(loads('{"choices":[]}'))


def test_chat_completion_response_rejects_top_level_scalar() raises:
    """A top-level JSON string is rejected as a response."""
    with assert_raises():
        _ = parse_query_analysis_from_chat_completion(loads('"not object"'))


def test_chat_completion_response_rejects_top_level_array() raises:
    """A top-level JSON array is rejected as a response."""
    with assert_raises():
        _ = parse_query_analysis_from_chat_completion(loads("[]"))


def test_chat_completion_response_rejects_top_level_null() raises:
    """A top-level JSON null is rejected as a response."""
    with assert_raises():
        _ = parse_query_analysis_from_chat_completion(loads("null"))


def main() raises:
    """Discovers and runs every test function defined in this module."""
    TestSuite.discover_tests[__functions_in_module()]().run()