commit b06d3ffcdde613510a5819df5ffa3b082450c5d3
parent d3a19bc664bafdda49a60055118c321b19dba4ad
Author: triesap <tyson@radroots.org>
Date: Thu, 18 Jun 2026 13:46:12 -0700
provider: type max local execution failures
- return explicit MAX-local success and failure outcomes
- move stdio fallback handling to provider outcome reasons
- add typed health and completion failure reason checks
- keep provider error public reason strings stable
Diffstat:
4 files changed, 249 insertions(+), 58 deletions(-)
diff --git a/src/hyf_provider/health.mojo b/src/hyf_provider/health.mojo
@@ -4,6 +4,16 @@ from hyf_provider.config import MaxLocalProviderConfig
from hyf_provider.result import MaxLocalProviderStatus
+@fieldwise_init
+struct MaxLocalHealthFailure(Copyable, Movable):
+ var kind: String
+ var reason: String
+
+
+def _health_failure(kind: String, reason: String) -> MaxLocalHealthFailure:
+ return MaxLocalHealthFailure(kind=String(kind), reason=String(reason))
+
+
def _provider_status(
config: MaxLocalProviderConfig,
reachable: Bool,
@@ -21,13 +31,22 @@ def _provider_status(
)
-def _classify_health_error(message: String) -> String:
+def max_local_health_failure_from_error(
+ message: String,
+) -> MaxLocalHealthFailure:
+ if message == "invalid_url" or message.find("invalid_url") >= 0:
+ return _health_failure("transport", "invalid_url")
+ if message == "timeout" or message.find("timeout") >= 0:
+ return _health_failure("transport", "timeout")
+ if message == "connection_failed" or message.find("connection_failed") >= 0:
+ return _health_failure("transport", "connection_failed")
+
var lower = message.lower()
- if lower.find("timeout") >= 0 or lower.find("timed out") >= 0:
- return "timeout"
if lower.find("url") >= 0 or lower.find("scheme") >= 0:
- return "invalid_url"
- return "connection_failed"
+ return _health_failure("transport", "invalid_url")
+ if lower.find("timeout") >= 0 or lower.find("timed out") >= 0:
+ return _health_failure("transport", "timeout")
+ return _health_failure("transport", "connection_failed")
def resolve_max_local_provider_status(
@@ -40,9 +59,10 @@ def resolve_max_local_provider_status(
return _provider_status(config, True, "ready", "ready")
return _provider_status(config, False, "unavailable", "non_2xx")
except e:
+ var failure = max_local_health_failure_from_error(String(e))
return _provider_status(
config,
False,
"unavailable",
- _classify_health_error(String(e)),
+ String(failure.reason),
)
diff --git a/src/hyf_provider/max_local.mojo b/src/hyf_provider/max_local.mojo
@@ -1,3 +1,4 @@
+from std.collections import Optional
from std.time import perf_counter_ns
from hyf_assist.contract import max_local_query_rewrite_route
@@ -31,30 +32,150 @@ struct MaxLocalQueryRewriteResult(Copyable, Movable):
var prompt_version: String
+@fieldwise_init
+struct MaxLocalQueryRewriteFailure(Copyable, Movable):
+ var kind: String
+ var reason: String
+
+
+@fieldwise_init
+struct MaxLocalQueryRewriteOutcome(Copyable, Movable):
+ var result: Optional[MaxLocalQueryRewriteResult]
+ var failure: Optional[MaxLocalQueryRewriteFailure]
+
+
+def _query_rewrite_success_outcome(
+ result: MaxLocalQueryRewriteResult
+) -> MaxLocalQueryRewriteOutcome:
+ return MaxLocalQueryRewriteOutcome(
+ result=Optional[MaxLocalQueryRewriteResult](result.copy()),
+ failure=Optional[MaxLocalQueryRewriteFailure](None),
+ )
+
+
+def _query_rewrite_failure_outcome(
+ kind: String, reason: String
+) -> MaxLocalQueryRewriteOutcome:
+ return MaxLocalQueryRewriteOutcome(
+ result=Optional[MaxLocalQueryRewriteResult](None),
+ failure=Optional[MaxLocalQueryRewriteFailure](
+ MaxLocalQueryRewriteFailure(
+ kind=String(kind), reason=String(reason)
+ )
+ ),
+ )
+
+
+def _matches_provider_reason(message: String, reason: String) -> Bool:
+ return message == reason or message.find(reason) >= 0
+
+
+def max_local_query_rewrite_failure_from_error(
+ message: String,
+) -> MaxLocalQueryRewriteFailure:
+ if _matches_provider_reason(message, "invalid_url"):
+ return MaxLocalQueryRewriteFailure(
+ kind="transport", reason="invalid_url"
+ )
+ if _matches_provider_reason(message, "provider_non_2xx"):
+ return MaxLocalQueryRewriteFailure(
+ kind="http_status", reason="provider_non_2xx"
+ )
+ if _matches_provider_reason(message, "provider_error_payload"):
+ return MaxLocalQueryRewriteFailure(
+ kind="provider_payload", reason="provider_error_payload"
+ )
+ if _matches_provider_reason(message, "provider_invalid_json"):
+ return MaxLocalQueryRewriteFailure(
+ kind="provider_payload", reason="provider_invalid_json"
+ )
+ if _matches_provider_reason(message, "provider_schema_invalid"):
+ return MaxLocalQueryRewriteFailure(
+ kind="provider_payload", reason="provider_schema_invalid"
+ )
+ if _matches_provider_reason(message, "provider_empty_choices"):
+ return MaxLocalQueryRewriteFailure(
+ kind="provider_payload", reason="provider_empty_choices"
+ )
+ if _matches_provider_reason(message, "provider_missing_content"):
+ return MaxLocalQueryRewriteFailure(
+ kind="provider_payload", reason="provider_missing_content"
+ )
+ if _matches_provider_reason(message, "provider_invalid_response"):
+ return MaxLocalQueryRewriteFailure(
+ kind="provider_payload", reason="provider_invalid_response"
+ )
+
+ var lower = message.lower()
+ if lower.find("url") >= 0 or lower.find("scheme") >= 0:
+ return MaxLocalQueryRewriteFailure(
+ kind="transport", reason="invalid_url"
+ )
+ if lower.find("timeout") >= 0 or lower.find("timed out") >= 0:
+ return MaxLocalQueryRewriteFailure(
+ kind="transport", reason="timeout"
+ )
+ if lower.find("connection") >= 0:
+ return MaxLocalQueryRewriteFailure(
+ kind="transport", reason="connection_failed"
+ )
+ return MaxLocalQueryRewriteFailure(
+ kind="provider", reason="provider_error"
+ )
+
+
def execute_query_rewrite_via_max_local_provider(
config: MaxLocalProviderConfig, text: String, context: RequestContext
) raises -> MaxLocalQueryRewriteResult:
+ var outcome = try_execute_query_rewrite_via_max_local_provider(
+ config, text, context
+ )
+ if outcome.result:
+ return outcome.result.value().copy()
+ if outcome.failure:
+ raise Error(String(outcome.failure.value().reason))
+ raise Error("provider_error")
+
+
+def try_execute_query_rewrite_via_max_local_provider(
+ config: MaxLocalProviderConfig, text: String, context: RequestContext
+) -> MaxLocalQueryRewriteOutcome:
with make_max_local_http_client(config) as client:
- var start_ns = perf_counter_ns()
- var response = client.post(
- max_local_chat_completions_url(config),
- build_query_rewrite_request_body(config, text, context),
- )
- var latency_ms = Int((perf_counter_ns() - start_ns) // 1_000_000)
- if not response.ok():
- raise Error("provider_non_2xx")
+ try:
+ var start_ns = perf_counter_ns()
+ var response = client.post(
+ max_local_chat_completions_url(config),
+ build_query_rewrite_request_body(config, text, context),
+ )
+ var latency_ms = Int(
+ (perf_counter_ns() - start_ns) // 1_000_000
+ )
+ if not response.ok():
+ return _query_rewrite_failure_outcome(
+ "http_status", "provider_non_2xx"
+ )
- return MaxLocalQueryRewriteResult(
- analysis=parse_query_analysis_from_chat_completion(
+ var analysis = parse_query_analysis_from_chat_completion(
response.json()
- ),
- provider="max_local",
- route=max_local_query_rewrite_route(),
- model=String(config.model),
- latency_ms=latency_ms,
- schema_version=query_rewrite_schema_version(),
- prompt_version=query_rewrite_prompt_version(),
- )
+ )
+ return _query_rewrite_success_outcome(
+ MaxLocalQueryRewriteResult(
+ analysis=analysis^,
+ provider="max_local",
+ route=max_local_query_rewrite_route(),
+ model=String(config.model),
+ latency_ms=latency_ms,
+ schema_version=query_rewrite_schema_version(),
+ prompt_version=query_rewrite_prompt_version(),
+ )
+ )
+ except e:
+ var failure = max_local_query_rewrite_failure_from_error(
+ String(e)
+ )
+ return _query_rewrite_failure_outcome(
+ String(failure.kind), String(failure.reason)
+ )
def max_local_provider_status(
diff --git a/src/hyf_stdio/provider_execution.mojo b/src/hyf_stdio/provider_execution.mojo
@@ -38,8 +38,8 @@ from hyf_provider.config import (
)
from hyf_provider.max_local import (
MaxLocalQueryRewriteResult,
- execute_query_rewrite_via_max_local_provider,
max_local_provider_status,
+ try_execute_query_rewrite_via_max_local_provider,
)
from hyf_runtime.config import (
HyfLoadedRuntimeConfig,
@@ -107,37 +107,6 @@ def _provider_runtime_config_fallback_reason(
return Optional[String](None)
-def _provider_execution_error_reason(message: String) -> String:
- var lower = message.lower()
- if lower.find("timeout") >= 0 or lower.find("timed out") >= 0:
- return "timeout"
- if lower.find("connection") >= 0:
- return "connection_failed"
- if lower.find("provider_non_2xx") >= 0 or lower.find("http") >= 0:
- return "provider_non_2xx"
- if lower.find("provider_error_payload") >= 0:
- return "provider_error_payload"
- if lower.find("provider_invalid_json") >= 0:
- return "provider_invalid_json"
- if lower.find("provider_schema_invalid") >= 0:
- return "provider_schema_invalid"
- if lower.find("provider_empty_choices") >= 0:
- return "provider_empty_choices"
- if lower.find("provider_missing_content") >= 0:
- return "provider_missing_content"
- if lower.find("provider_invalid_response") >= 0:
- return "provider_invalid_response"
- if lower.find("choices") >= 0:
- return "provider_empty_choices"
- if lower.find("content") >= 0 or lower.find("message") >= 0:
- return "provider_missing_content"
- if lower.find("json") >= 0 or lower.find("parse") >= 0:
- return "provider_invalid_json"
- if lower.find("schema") >= 0 or lower.find("required") >= 0:
- return "provider_schema_invalid"
- return "provider_error"
-
-
def _effective_provider_budget_ms(
config: MaxLocalProviderConfig, context: RequestContext
) -> Int:
@@ -270,9 +239,25 @@ def _execute_query_rewrite_with_provider(
provider_config = _provider_config_with_timeout(
provider_config, remaining_ms
)
- var result = execute_query_rewrite_via_max_local_provider(
+ var outcome = try_execute_query_rewrite_via_max_local_provider(
provider_config, request.text, context
)
+ if outcome.failure:
+ return _query_rewrite_fallback(
+ input,
+ context,
+ "provider_runtime",
+ String(outcome.failure.value().reason),
+ )
+ if not outcome.result:
+ return _query_rewrite_fallback(
+ input,
+ context,
+ "provider_runtime",
+ "provider_error",
+ )
+
+ var result = outcome.result.value().copy()
return successful_capability(
build_query_rewrite_output(result.analysis),
meta=_provider_meta(context, result),
@@ -282,7 +267,7 @@ def _execute_query_rewrite_with_provider(
input,
context,
"provider_runtime",
- _provider_execution_error_reason(String(e)),
+ "provider_error",
)
diff --git a/tests/test_provider_adapter.mojo b/tests/test_provider_adapter.mojo
@@ -9,6 +9,8 @@ from hyf_provider.config import (
MaxLocalProviderConfig,
max_local_provider_config_from_runtime,
)
+from hyf_provider.health import max_local_health_failure_from_error
+from hyf_provider.max_local import max_local_query_rewrite_failure_from_error
from hyf_provider.result import parse_query_analysis_from_chat_completion
from hyf_provider.schema import build_query_rewrite_request_body
from hyf_runtime.config import (
@@ -86,6 +88,22 @@ def _chat_completion_response() raises -> Value:
return response^
+def _assert_query_rewrite_failure(
+ message: String, expected_kind: String, expected_reason: String
+) raises:
+ var failure = max_local_query_rewrite_failure_from_error(message)
+ assert_equal(failure.kind, expected_kind)
+ assert_equal(failure.reason, expected_reason)
+
+
+def _assert_health_failure(
+ message: String, expected_kind: String, expected_reason: String
+) raises:
+ var failure = max_local_health_failure_from_error(message)
+ assert_equal(failure.kind, expected_kind)
+ assert_equal(failure.reason, expected_reason)
+
+
def test_provider_config_maps_runtime_config() raises:
var config = max_local_provider_config_from_runtime(
_provider_runtime_config()
@@ -104,6 +122,53 @@ def test_max_local_route_is_derived_from_assisted_contract() raises:
)
+def test_max_local_provider_failure_mapping_preserves_reason_tokens() raises:
+ _assert_query_rewrite_failure("timed out", "transport", "timeout")
+ _assert_query_rewrite_failure(
+ "connection refused", "transport", "connection_failed"
+ )
+ _assert_query_rewrite_failure("bad url scheme", "transport", "invalid_url")
+ _assert_query_rewrite_failure(
+ "provider_non_2xx", "http_status", "provider_non_2xx"
+ )
+ _assert_query_rewrite_failure(
+ "provider_error_payload",
+ "provider_payload",
+ "provider_error_payload",
+ )
+ _assert_query_rewrite_failure(
+ "provider_invalid_json",
+ "provider_payload",
+ "provider_invalid_json",
+ )
+ _assert_query_rewrite_failure(
+ "provider_schema_invalid",
+ "provider_payload",
+ "provider_schema_invalid",
+ )
+ _assert_query_rewrite_failure(
+ "provider_empty_choices",
+ "provider_payload",
+ "provider_empty_choices",
+ )
+ _assert_query_rewrite_failure(
+ "provider_missing_content",
+ "provider_payload",
+ "provider_missing_content",
+ )
+ _assert_query_rewrite_failure(
+ "unexpected provider failure", "provider", "provider_error"
+ )
+
+
+def test_max_local_health_failure_mapping_preserves_reason_tokens() raises:
+ _assert_health_failure("timed out", "transport", "timeout")
+ _assert_health_failure("bad url scheme", "transport", "invalid_url")
+ _assert_health_failure(
+ "unexpected health failure", "transport", "connection_failed"
+ )
+
+
def test_provider_config_rejects_unconfigured_runtime() raises:
with assert_raises():
_ = max_local_provider_config_from_runtime(