commit 82264d575e514abaab89715d08bfd39d93dd488f
parent 9d9ace6e7670faccb5f3a0223364f970d82c4813
Author: triesap <tyson@radroots.org>
Date: Thu, 18 Jun 2026 20:37:01 -0700
provider: isolate max local transport boundary
Diffstat:
4 files changed, 244 insertions(+), 96 deletions(-)
diff --git a/src/hyf_provider/client.mojo b/src/hyf_provider/client.mojo
@@ -1,17 +1,133 @@
+from std.collections import Optional
+from std.time import perf_counter_ns
+
+from json import Value
from flare.http import HttpClient
from hyf_provider.config import MaxLocalProviderConfig
+@fieldwise_init
+struct MaxLocalTransportResponse(Copyable, Movable):
+ var status: Int
+ var body_text: String
+ var latency_ms: Int
+
+
+@fieldwise_init
+struct MaxLocalTransportFailure(Copyable, Movable):
+ var kind: String
+ var reason: String
+
+
+@fieldwise_init
+struct MaxLocalTransportOutcome(Copyable, Movable):
+ var response: Optional[MaxLocalTransportResponse]
+ var failure: Optional[MaxLocalTransportFailure]
+
+
def _trim_trailing_slash(url: String) -> String:
if url.endswith("/") and url.byte_length() > 1:
return String(url[byte = 0 : url.byte_length() - 1])
return String(url)
+def _http_url(url: String) -> Bool:
+ return url.startswith("http://") or url.startswith("https://")
+
+
+def _elapsed_ms_since(start_ns: UInt) -> Int:
+ return Int((perf_counter_ns() - start_ns) // 1_000_000)
+
+
+def _transport_response_outcome(
+ status: Int, body_text: String, latency_ms: Int
+) -> MaxLocalTransportOutcome:
+ return MaxLocalTransportOutcome(
+ response=Optional[MaxLocalTransportResponse](
+ MaxLocalTransportResponse(
+ status=status,
+ body_text=String(body_text),
+ latency_ms=latency_ms,
+ )
+ ),
+ failure=Optional[MaxLocalTransportFailure](None),
+ )
+
+
+def _transport_failure_outcome(
+ kind: String, reason: String
+) -> MaxLocalTransportOutcome:
+ return MaxLocalTransportOutcome(
+ response=Optional[MaxLocalTransportResponse](None),
+ failure=Optional[MaxLocalTransportFailure](
+ MaxLocalTransportFailure(
+ kind=String(kind), reason=String(reason)
+ )
+ ),
+ )
+
+
+def _transport_exception_reason(
+ start_ns: UInt, request_timeout_ms: Int
+) -> String:
+ if _elapsed_ms_since(start_ns) >= request_timeout_ms:
+ return "timeout"
+ return "connection_failed"
+
+
def make_max_local_http_client(config: MaxLocalProviderConfig) -> HttpClient:
return HttpClient(timeout_ms=config.request_timeout_ms)
def max_local_chat_completions_url(config: MaxLocalProviderConfig) -> String:
return _trim_trailing_slash(config.base_url) + "/chat/completions"
+
+
+def get_max_local_health(
+ config: MaxLocalProviderConfig,
+) -> MaxLocalTransportOutcome:
+ if not _http_url(config.health_url):
+ return _transport_failure_outcome("transport", "invalid_url")
+
+ var start_ns = perf_counter_ns()
+ try:
+ with make_max_local_http_client(config) as client:
+ var response = client.get(config.health_url)
+ var latency_ms = _elapsed_ms_since(start_ns)
+ if not response.ok():
+ return _transport_failure_outcome("http_status", "non_2xx")
+ return _transport_response_outcome(
+ response.status, response.text(), latency_ms
+ )
+ except:
+ return _transport_failure_outcome(
+ "transport",
+ _transport_exception_reason(start_ns, config.request_timeout_ms),
+ )
+
+
+def post_max_local_chat_completion(
+ config: MaxLocalProviderConfig, body: Value
+) -> MaxLocalTransportOutcome:
+ var url = max_local_chat_completions_url(config)
+ if not _http_url(url):
+ return _transport_failure_outcome("transport", "invalid_url")
+
+ var start_ns = perf_counter_ns()
+ try:
+ with make_max_local_http_client(config) as client:
+ var response = client.post(url, body)
+ var latency_ms = _elapsed_ms_since(start_ns)
+ if not response.ok():
+ return _transport_failure_outcome(
+ "http_status", "provider_non_2xx"
+ )
+ return _transport_response_outcome(
+ response.status, response.text(), latency_ms
+ )
+ except:
+ return _transport_failure_outcome(
+ "transport",
+ _transport_exception_reason(start_ns, config.request_timeout_ms),
+ )
diff --git a/src/hyf_provider/health.mojo b/src/hyf_provider/health.mojo
@@ -1,7 +1,5 @@
-from std.time import perf_counter_ns
-
from hyf_assist.contract import max_local_query_rewrite_route
-from hyf_provider.client import make_max_local_http_client
+from hyf_provider.client import get_max_local_health
from hyf_provider.config import MaxLocalProviderConfig
from hyf_provider.result import MaxLocalProviderStatus
@@ -33,44 +31,41 @@ def _provider_status(
)
-def _elapsed_ms_since(start_ns: UInt) -> Int:
- return Int((perf_counter_ns() - start_ns) // 1_000_000)
-
-
-def max_local_health_failure_from_error(
- message: String,
+def max_local_health_failure_from_reason(
+ reason: String,
) -> MaxLocalHealthFailure:
- if message == "invalid_url":
+ if reason == "invalid_url":
return _health_failure("transport", "invalid_url")
- if message == "timeout":
+ if reason == "timeout":
return _health_failure("transport", "timeout")
- if message == "connection_failed":
+ if reason == "connection_failed":
return _health_failure("transport", "connection_failed")
+ if reason == "non_2xx":
+ return _health_failure("http_status", "non_2xx")
return _health_failure("transport", "connection_failed")
def resolve_max_local_provider_status(
config: MaxLocalProviderConfig,
) -> MaxLocalProviderStatus:
- var start_ns = perf_counter_ns()
- try:
- with make_max_local_http_client(config) as client:
- var response = client.get(config.health_url)
- if response.ok():
- return _provider_status(config, True, "ready", "ready")
- return _provider_status(config, False, "unavailable", "non_2xx")
- except e:
- var failure = max_local_health_failure_from_error(String(e))
- if (
- failure.reason == "connection_failed"
- and _elapsed_ms_since(start_ns) >= config.request_timeout_ms
- ):
- failure = MaxLocalHealthFailure(
- kind="transport", reason="timeout"
- )
+ var transport = get_max_local_health(config)
+ if transport.response:
+ return _provider_status(config, True, "ready", "ready")
+
+ if transport.failure:
+ var failure = max_local_health_failure_from_reason(
+ transport.failure.value().reason
+ )
return _provider_status(
config,
False,
"unavailable",
String(failure.reason),
)
+
+ return _provider_status(
+ config,
+ False,
+ "unavailable",
+ "connection_failed",
+ )
diff --git a/src/hyf_provider/max_local.mojo b/src/hyf_provider/max_local.mojo
@@ -1,13 +1,11 @@
from std.collections import Optional
-from std.time import perf_counter_ns
+
+from json import Value, loads
from hyf_assist.contract import max_local_query_rewrite_route
from hyf_core.capabilities.query_analysis import QueryAnalysis
from hyf_core.request_context import RequestContext
-from hyf_provider.client import (
- make_max_local_http_client,
- max_local_chat_completions_url,
-)
+from hyf_provider.client import post_max_local_chat_completion
from hyf_provider.config import MaxLocalProviderConfig
from hyf_provider.health import resolve_max_local_provider_status
from hyf_provider.result import (
@@ -66,46 +64,42 @@ def _query_rewrite_failure_outcome(
)
-def _elapsed_ms_since(start_ns: UInt) -> Int:
- return Int((perf_counter_ns() - start_ns) // 1_000_000)
-
-
-def max_local_query_rewrite_failure_from_error(
- message: String,
+def max_local_query_rewrite_failure_from_reason(
+ reason: String,
) -> MaxLocalQueryRewriteFailure:
- if message == "invalid_url":
+ if reason == "invalid_url":
return MaxLocalQueryRewriteFailure(
kind="transport", reason="invalid_url"
)
- if message == "timeout":
+ if reason == "timeout":
return MaxLocalQueryRewriteFailure(
kind="transport", reason="timeout"
)
- if message == "connection_failed":
+ if reason == "connection_failed":
return MaxLocalQueryRewriteFailure(
kind="transport", reason="connection_failed"
)
- if message == "provider_non_2xx":
+ if reason == "provider_non_2xx":
return MaxLocalQueryRewriteFailure(
kind="http_status", reason="provider_non_2xx"
)
- if message == "provider_error_payload":
+ if reason == "provider_error_payload":
return MaxLocalQueryRewriteFailure(
kind="provider_payload", reason="provider_error_payload"
)
- if message == "provider_invalid_json":
+ if reason == "provider_invalid_json":
return MaxLocalQueryRewriteFailure(
kind="provider_payload", reason="provider_invalid_json"
)
- if message == "provider_schema_invalid":
+ if reason == "provider_schema_invalid":
return MaxLocalQueryRewriteFailure(
kind="provider_payload", reason="provider_schema_invalid"
)
- if message == "provider_empty_choices":
+ if reason == "provider_empty_choices":
return MaxLocalQueryRewriteFailure(
kind="provider_payload", reason="provider_empty_choices"
)
- if message == "provider_missing_content":
+ if reason == "provider_missing_content":
return MaxLocalQueryRewriteFailure(
kind="provider_payload", reason="provider_missing_content"
)
@@ -114,6 +108,19 @@ def max_local_query_rewrite_failure_from_error(
)
+def _load_chat_completion_response_json(text: String) raises -> Value:
+ try:
+ return loads(text)
+ except:
+ raise Error("provider_invalid_json")
+
+
+def _parse_query_analysis_from_body(text: String) raises -> QueryAnalysis:
+ return parse_query_analysis_from_chat_completion(
+ _load_chat_completion_response_json(text)
+ )
+
+
def execute_query_rewrite_via_max_local_provider(
config: MaxLocalProviderConfig, text: String, context: RequestContext
) raises -> MaxLocalQueryRewriteResult:
@@ -130,50 +137,49 @@ def execute_query_rewrite_via_max_local_provider(
def try_execute_query_rewrite_via_max_local_provider(
config: MaxLocalProviderConfig, text: String, context: RequestContext
) -> MaxLocalQueryRewriteOutcome:
- with make_max_local_http_client(config) as client:
- var start_ns = perf_counter_ns()
- try:
- var response = client.post(
- max_local_chat_completions_url(config),
- build_query_rewrite_request_body(config, text, context),
- )
- var latency_ms = Int(
- (perf_counter_ns() - start_ns) // 1_000_000
- )
- if not response.ok():
- return _query_rewrite_failure_outcome(
- "http_status", "provider_non_2xx"
- )
+ var request_body: Value
+ try:
+ request_body = build_query_rewrite_request_body(config, text, context)
+ except:
+ return _query_rewrite_failure_outcome("provider", "provider_error")
+
+ var transport = post_max_local_chat_completion(
+ config,
+ request_body^,
+ )
+ if transport.failure:
+ var failure = max_local_query_rewrite_failure_from_reason(
+ transport.failure.value().reason
+ )
+ return _query_rewrite_failure_outcome(
+ String(failure.kind), String(failure.reason)
+ )
- var analysis = parse_query_analysis_from_chat_completion(
- response.json()
- )
+ if transport.response:
+ try:
+ var response = transport.response.value().copy()
+ var analysis = _parse_query_analysis_from_body(response.body_text)
return _query_rewrite_success_outcome(
MaxLocalQueryRewriteResult(
analysis=analysis^,
provider="max_local",
route=max_local_query_rewrite_route(),
model=String(config.model),
- latency_ms=latency_ms,
+ latency_ms=response.latency_ms,
schema_version=query_rewrite_schema_version(),
prompt_version=query_rewrite_prompt_version(),
)
)
except e:
- var failure = max_local_query_rewrite_failure_from_error(
+ var failure = max_local_query_rewrite_failure_from_reason(
String(e)
)
- if (
- failure.reason == "provider_error"
- and _elapsed_ms_since(start_ns) >= config.request_timeout_ms
- ):
- failure = MaxLocalQueryRewriteFailure(
- kind="transport", reason="timeout"
- )
return _query_rewrite_failure_outcome(
String(failure.kind), String(failure.reason)
)
+ return _query_rewrite_failure_outcome("provider", "provider_error")
+
def max_local_provider_status(
config: MaxLocalProviderConfig,
diff --git a/tests/test_provider_adapter.mojo b/tests/test_provider_adapter.mojo
@@ -4,13 +4,17 @@ from json import Value, loads
from hyf_assist.contract import max_local_query_rewrite_route
from hyf_core.request_context import default_request_context
-from hyf_provider.client import max_local_chat_completions_url
+from hyf_provider.client import (
+ get_max_local_health,
+ max_local_chat_completions_url,
+ post_max_local_chat_completion,
+)
from hyf_provider.config import (
MaxLocalProviderConfig,
max_local_provider_config_from_runtime,
)
-from hyf_provider.health import max_local_health_failure_from_error
-from hyf_provider.max_local import max_local_query_rewrite_failure_from_error
+from hyf_provider.health import max_local_health_failure_from_reason
+from hyf_provider.max_local import max_local_query_rewrite_failure_from_reason
from hyf_provider.result import parse_query_analysis_from_chat_completion
from hyf_provider.schema import build_query_rewrite_request_body
from hyf_runtime.config import (
@@ -60,6 +64,24 @@ def _provider_config() -> MaxLocalProviderConfig:
)
+def _invalid_base_url_provider_config() -> MaxLocalProviderConfig:
+ return MaxLocalProviderConfig(
+ base_url="ftp://127.0.0.1:8000/v1/",
+ health_url="http://127.0.0.1:8000/health",
+ model="max-local-query-rewrite",
+ request_timeout_ms=15000,
+ )
+
+
+def _invalid_health_url_provider_config() -> MaxLocalProviderConfig:
+ return MaxLocalProviderConfig(
+ base_url="http://127.0.0.1:8000/v1/",
+ health_url="ftp://127.0.0.1:8000/health",
+ model="max-local-query-rewrite",
+ request_timeout_ms=15000,
+ )
+
+
def _analysis_json_text() -> String:
return (
'{"original_text":"eggs near me",'
@@ -93,17 +115,17 @@ def _chat_completion_response() raises -> Value:
def _assert_query_rewrite_failure(
- message: String, expected_kind: String, expected_reason: String
+ reason: String, expected_kind: String, expected_reason: String
) raises:
- var failure = max_local_query_rewrite_failure_from_error(message)
+ var failure = max_local_query_rewrite_failure_from_reason(reason)
assert_equal(failure.kind, expected_kind)
assert_equal(failure.reason, expected_reason)
def _assert_health_failure(
- message: String, expected_kind: String, expected_reason: String
+ reason: String, expected_kind: String, expected_reason: String
) raises:
- var failure = max_local_health_failure_from_error(message)
+ var failure = max_local_health_failure_from_reason(reason)
assert_equal(failure.kind, expected_kind)
assert_equal(failure.reason, expected_reason)
@@ -172,19 +194,10 @@ def test_max_local_provider_failure_mapping_preserves_reason_tokens() raises:
"provider_missing_content",
)
_assert_query_rewrite_failure(
- "unexpected provider failure", "provider", "provider_error"
- )
- _assert_query_rewrite_failure(
- "timed out", "provider", "provider_error"
- )
- _assert_query_rewrite_failure(
- "connection refused", "provider", "provider_error"
- )
- _assert_query_rewrite_failure(
- "bad url scheme", "provider", "provider_error"
+ "unknown_transport", "provider", "provider_error"
)
_assert_query_rewrite_failure(
- "not a timeout", "provider", "provider_error"
+ "unknown_provider", "provider", "provider_error"
)
@@ -194,12 +207,10 @@ def test_max_local_health_failure_mapping_preserves_reason_tokens() raises:
_assert_health_failure(
"connection_failed", "transport", "connection_failed"
)
+ _assert_health_failure("non_2xx", "http_status", "non_2xx")
_assert_health_failure(
- "unexpected health failure", "transport", "connection_failed"
+ "unknown_transport", "transport", "connection_failed"
)
- _assert_health_failure("timed out", "transport", "connection_failed")
- _assert_health_failure("bad url scheme", "transport", "connection_failed")
- _assert_health_failure("not a timeout", "transport", "connection_failed")
def test_provider_config_rejects_unconfigured_runtime() raises:
@@ -216,6 +227,26 @@ def test_max_local_chat_completions_url_trims_base_url() raises:
)
+def test_max_local_transport_boundary_rejects_invalid_chat_url() raises:
+ var outcome = post_max_local_chat_completion(
+ _invalid_base_url_provider_config(), loads("{}")
+ )
+
+ assert_true(outcome.failure)
+ assert_true(not outcome.response)
+ assert_equal(outcome.failure.value().kind, "transport")
+ assert_equal(outcome.failure.value().reason, "invalid_url")
+
+
+def test_max_local_transport_boundary_rejects_invalid_health_url() raises:
+ var outcome = get_max_local_health(_invalid_health_url_provider_config())
+
+ assert_true(outcome.failure)
+ assert_true(not outcome.response)
+ assert_equal(outcome.failure.value().kind, "transport")
+ assert_equal(outcome.failure.value().reason, "invalid_url")
+
+
def test_query_rewrite_request_body_sets_schema_contract() raises:
var context = default_request_context()
context.return_provenance = True