diff --git a/runpod/endpoint/helpers.py b/runpod/endpoint/helpers.py index cbdb83f9..acef8b4e 100644 --- a/runpod/endpoint/helpers.py +++ b/runpod/endpoint/helpers.py @@ -1,6 +1,6 @@ """ Helper functions for the Runpod Endpoint API. """ -FINAL_STATES = ["COMPLETED", "FAILED", "TIMED_OUT"] +FINAL_STATES = ["COMPLETED", "FAILED", "TIMED_OUT", "CANCELLED"] # Exception Messages UNAUTHORIZED_MSG = "401 Unauthorized | Make sure Runpod API key is set and valid." @@ -14,4 +14,4 @@ def is_completed(status: str) -> bool: """Returns true if status is one of the possible final states for a serverless request.""" - return status in ["COMPLETED", "FAILED", "TIMED_OUT", "CANCELLED"] + return status in FINAL_STATES diff --git a/runpod/endpoint/runner.py b/runpod/endpoint/runner.py index d8c48759..a3127edd 100644 --- a/runpod/endpoint/runner.py +++ b/runpod/endpoint/runner.py @@ -44,8 +44,17 @@ def __init__(self, api_key: Optional[str] = None): raise RuntimeError(API_KEY_NOT_SET_MSG) self.rp_session = requests.Session() - retries = Retry(total=5, backoff_factor=1, status_forcelist=[408, 429]) - self.rp_session.mount("http://", HTTPAdapter(max_retries=retries)) + retries = Retry( + total=5, + backoff_factor=1, + status_forcelist=[408, 429, 500, 502, 503, 504], + allowed_methods=frozenset(["GET", "POST"]), + ) + adapter = HTTPAdapter(max_retries=retries) + # The production API is served over https; mount on both schemes so the + # retry/backoff policy actually applies to real traffic. + self.rp_session.mount("https://", adapter) + self.rp_session.mount("http://", adapter) self.headers = { "Content-Type": "application/json", diff --git a/runpod/serverless/modules/rp_http.py b/runpod/serverless/modules/rp_http.py index 3d82d35b..b326fa9a 100644 --- a/runpod/serverless/modules/rp_http.py +++ b/runpod/serverless/modules/rp_http.py @@ -26,7 +26,7 @@ log = RunPodLogger() -async def _transmit(client_session: ClientSession, url, job_data): +async def _transmit(client_session: ClientSession, url, job_data, request_id=None): """ Wrapper for transmitting results via POST. """ @@ -35,12 +35,18 @@ async def _transmit(client_session: ClientSession, url, job_data): client_session=client_session, retry_options=retry_options ) + headers = { + "charset": "utf-8", + "Content-Type": "application/x-www-form-urlencoded", + } + # Pass the request id per-request rather than mutating the shared session's + # headers, which would race across concurrently handled jobs. + if request_id is not None: + headers["X-Request-ID"] = request_id + kwargs = { "data": job_data, - "headers": { - "charset": "utf-8", - "Content-Type": "application/x-www-form-urlencoded", - }, + "headers": headers, "raise_for_status": True, } @@ -55,14 +61,12 @@ async def _handle_result( A helper function to handle the result, either for sending or streaming. """ try: - session.headers["X-Request-ID"] = job["id"] - serialized_job_data = json.dumps(job_data, ensure_ascii=False) is_stream = "true" if is_stream else "false" url = url_template.replace("$ID", job["id"]) + f"&isStream={is_stream}" - await _transmit(session, url, serialized_job_data) + await _transmit(session, url, serialized_job_data, request_id=job["id"]) log.debug(f"{log_message}", job["id"]) except ClientError as err: diff --git a/tests/test_serverless/test_modules/test_http.py b/tests/test_serverless/test_modules/test_http.py index c7462245..62249907 100644 --- a/tests/test_serverless/test_modules/test_http.py +++ b/tests/test_serverless/test_modules/test_http.py @@ -65,6 +65,7 @@ async def test_send_result(self): headers={ "charset": "utf-8", "Content-Type": "application/x-www-form-urlencoded", + "X-Request-ID": self.job["id"], }, raise_for_status=True, ) @@ -159,6 +160,7 @@ async def test_stream_result(self): headers={ "charset": "utf-8", "Content-Type": "application/x-www-form-urlencoded", + "X-Request-ID": self.job["id"], }, raise_for_status=True, )