commit d3a19bc664bafdda49a60055118c321b19dba4ad
parent e13907c7ae4aac5ad3d5ecdb19b82f38ff092528
Author: triesap <tyson@radroots.org>
Date: Thu, 18 Jun 2026 13:31:16 -0700
provider: enforce max local request deadline
- treat assisted deadline as one provider budget
- pass remaining budget to completion after readiness
- cover remaining-deadline timeout with the max local stub
Diffstat:
3 files changed, 99 insertions(+), 18 deletions(-)
diff --git a/src/hyf_stdio/provider_execution.mojo b/src/hyf_stdio/provider_execution.mojo
@@ -1,4 +1,5 @@
from std.collections import List, Optional
+from std.time import perf_counter_ns
from json import Value
@@ -137,18 +138,29 @@ def _provider_execution_error_reason(message: String) -> String:
return "provider_error"
-def _max_local_config_for_request(
- runtime_context: RuntimeStartupContext, context: RequestContext
-) raises -> MaxLocalProviderConfig:
- var config = max_local_provider_config_from_runtime(
- runtime_context.config
- )
- if (
- context.deadline_ms > 0
- and config.request_timeout_ms > context.deadline_ms
- ):
- config.request_timeout_ms = context.deadline_ms
- return config^
+def _effective_provider_budget_ms(
+ config: MaxLocalProviderConfig, context: RequestContext
+) -> Int:
+ var budget_ms = config.request_timeout_ms
+ if context.deadline_ms > 0 and context.deadline_ms < budget_ms:
+ budget_ms = context.deadline_ms
+ return budget_ms
+
+
+def _provider_config_with_timeout(
+ config: MaxLocalProviderConfig, timeout_ms: Int
+) -> MaxLocalProviderConfig:
+ var capped = config.copy()
+ capped.request_timeout_ms = timeout_ms
+ return capped^
+
+
+def _remaining_provider_budget_ms(start_ns: UInt, budget_ms: Int) -> Int:
+ var elapsed_ms = Int((perf_counter_ns() - start_ns) // 1_000_000)
+ var remaining_ms = budget_ms - elapsed_ms
+ if remaining_ms <= 0:
+ return 0
+ return remaining_ms
def _query_rewrite_fallback(
@@ -214,10 +226,16 @@ def _execute_query_rewrite_with_provider(
)
try:
- var provider_config = _max_local_config_for_request(
- runtime_context, context
+ var provider_config = max_local_provider_config_from_runtime(
+ runtime_context.config
+ )
+ var budget_ms = _effective_provider_budget_ms(
+ provider_config, context
+ )
+ var budget_start_ns = perf_counter_ns()
+ var provider_status = max_local_provider_status(
+ _provider_config_with_timeout(provider_config, budget_ms)
)
- var provider_status = max_local_provider_status(provider_config)
if provider_status.state != "ready":
return _query_rewrite_fallback(
input,
@@ -226,7 +244,32 @@ def _execute_query_rewrite_with_provider(
String(provider_status.reason),
)
+ var remaining_ms = _remaining_provider_budget_ms(
+ budget_start_ns, budget_ms
+ )
+ if remaining_ms <= 0:
+ return _query_rewrite_fallback(
+ input,
+ context,
+ "provider_runtime",
+ "timeout",
+ )
+
var request = parse_query_rewrite_request(input)
+ remaining_ms = _remaining_provider_budget_ms(
+ budget_start_ns, budget_ms
+ )
+ if remaining_ms <= 0:
+ return _query_rewrite_fallback(
+ input,
+ context,
+ "provider_runtime",
+ "timeout",
+ )
+
+ provider_config = _provider_config_with_timeout(
+ provider_config, remaining_ms
+ )
var result = execute_query_rewrite_via_max_local_provider(
provider_config, request.text, context
)
diff --git a/tests/max_local_process_helper.mojo b/tests/max_local_process_helper.mojo
@@ -122,6 +122,9 @@ def _handle_health(mut stream: TcpStream, mode: String) raises:
_send(stream, 503, '{"status":"unavailable"}')
elif mode == "health_timeout":
usleep(1_000_000)
+ elif mode == "query_rewrite_remaining_deadline_timeout":
+ usleep(200_000)
+ _send(stream, 200, '{"status":"ok"}')
else:
_send(stream, 200, '{"status":"ok"}')
@@ -156,6 +159,9 @@ def _handle_chat_completions(mut stream: TcpStream, mode: String) raises:
elif mode == "query_rewrite_timeout":
usleep(2_000_000)
_send(stream, 200, _chat_completion(_query_rewrite_analysis()))
+ elif mode == "query_rewrite_remaining_deadline_timeout":
+ usleep(400_000)
+ _send(stream, 200, _chat_completion(_query_rewrite_analysis()))
else:
_send(stream, 500, '{"error":"unsupported_mode"}')
diff --git a/tests/test_stdio_contract.mojo b/tests/test_stdio_contract.mojo
@@ -74,12 +74,22 @@ def _unavailable_max_local_runtime_config_toml() raises -> String:
def _query_rewrite_assisted_request_json(request_id: String) -> String:
+ return _query_rewrite_assisted_request_json_with_deadline(
+ request_id, 2500
+ )
+
+
+def _query_rewrite_assisted_request_json_with_deadline(
+ request_id: String, deadline_ms: Int
+) -> String:
return (
'{"version":1,"request_id":"'
+ request_id
+ '","trace_id":"'
+ request_id
- + '","capability":"query_rewrite","context":{"execution_mode_preference":"assisted","return_provenance":true},"input":{"query":"apples near me with weekend pickup"}}'
+ + '","capability":"query_rewrite","context":{"execution_mode_preference":"assisted","return_provenance":true,"deadline_ms":'
+ + String(deadline_ms)
+ + '},"input":{"query":"apples near me with weekend pickup"}}'
)
@@ -96,6 +106,18 @@ def _semantic_rank_assisted_request_json(request_id: String) -> String:
def _assert_query_rewrite_provider_fallback_with_requests(
mode: String, expected_reason: String, request_timeout_ms: Int, requests: Int
) raises:
+ _assert_query_rewrite_provider_fallback_with_deadline(
+ mode, expected_reason, request_timeout_ms, 2500, requests
+ )
+
+
+def _assert_query_rewrite_provider_fallback_with_deadline(
+ mode: String,
+ expected_reason: String,
+ request_timeout_ms: Int,
+ deadline_ms: Int,
+ requests: Int,
+) raises:
with TemporaryDirectory() as temp_dir:
var provider_port = reserve_loopback_port()
var provider_stub = spawn_max_local_stub(provider_port, mode, requests)
@@ -111,8 +133,8 @@ def _assert_query_rewrite_provider_fallback_with_requests(
with ScopedEnvVar(HYF_PATHS_REPO_LOCAL_ROOT_ENV, temp_dir):
var response = run_stdio_entrypoint(
"src/main.mojo",
- _query_rewrite_assisted_request_json(
- "rewrite-assisted-" + mode
+ _query_rewrite_assisted_request_json_with_deadline(
+ "rewrite-assisted-" + mode, deadline_ms
),
"--config",
startup_config_path.__fspath__(),
@@ -1409,6 +1431,16 @@ def test_query_rewrite_falls_back_when_provider_readiness_probe_times_out() rais
)
+def test_query_rewrite_completion_uses_remaining_deadline_after_readiness() raises:
+ _assert_query_rewrite_provider_fallback_with_deadline(
+ "query_rewrite_remaining_deadline_timeout",
+ "timeout",
+ 1000,
+ 500,
+ 2,
+ )
+
+
def test_query_rewrite_falls_back_on_provider_invalid_json() raises:
_assert_query_rewrite_provider_fallback(
"query_rewrite_invalid_json", "provider_invalid_json", 15000