# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
# pylint: disable=invalid-overridden-method
import asyncio
import inspect
import logging
import random
from typing import Any, Dict, TYPE_CHECKING

from azure.core.exceptions import AzureError, StreamClosedError, StreamConsumedError
from azure.core.pipeline.policies import AsyncBearerTokenCredentialPolicy, AsyncHTTPPolicy

from .authentication import AzureSigningError, StorageHttpChallenge
from .constants import DEFAULT_OAUTH_SCOPE
from .policies import encode_base64, is_retry, StorageContentValidation, StorageRetryPolicy

if TYPE_CHECKING:
    from azure.core.credentials_async import AsyncTokenCredential
    from azure.core.pipeline.transport import (  # pylint: disable=non-abstract-transport-import
        PipelineRequest,
        PipelineResponse
    )
# Module-level logger, named after this module.
_LOGGER = logging.getLogger(__name__)
async def retry_hook(settings, **kwargs):
    """Invoke the user-configured retry hook, if one is set.

    The hook may be a plain callable or an ``async def`` function (or any callable
    returning an awaitable); awaitable results are awaited so async hooks actually run.

    :param dict settings: Retry settings; reads ``'hook'``, ``'count'`` and ``'mode'``.
    :keyword kwargs: Extra keyword arguments forwarded to the hook
        (the callers in this module pass ``request``, ``response`` and ``error``).
    """
    hook = settings['hook']
    if not hook:
        return
    # ``count`` has already been bumped for the upcoming attempt, so the hook is told
    # the zero-based number of the retry.
    result = hook(
        retry_count=settings['count'] - 1,
        location_mode=settings['mode'],
        **kwargs)
    # Calling an ``async def`` hook only creates a coroutine; it must be awaited for the
    # hook body to execute. Checking the *result* (rather than the hook object) handles
    # async functions, async ``__call__`` objects and partials uniformly.
    if inspect.isawaitable(result):
        await result
async def is_checksum_retry(response):
    """Report whether ``response`` should be retried because of a Content-MD5 mismatch.

    The check only applies when the pipeline context has ``validate_content`` set and
    the response carries a ``content-md5`` header; otherwise this returns ``False``.

    :param response: The pipeline response to inspect.
    :return: ``True`` if the response's ``content-md5`` differs from the expected value.
    :rtype: bool
    """
    needs_check = response.context.get('validate_content', False) and \
        response.http_response.headers.get('content-md5')
    if not needs_check:
        return False

    http_response = response.http_response
    try:
        # Load the body in memory and close the socket
        await http_response.load_body()
    except (StreamClosedError, StreamConsumedError):
        # Stream already closed/consumed: fall through and hash whatever body() returns.
        pass

    # An MD5 sent with the request takes precedence; otherwise hash the received body.
    expected_md5 = response.http_request.headers.get('content-md5', None)
    if not expected_md5:
        expected_md5 = encode_base64(StorageContentValidation.get_content_md5(http_response.body()))
    return bool(http_response.headers['content-md5'] != expected_md5)
class AsyncStorageResponseHook(AsyncHTTPPolicy):
    """Async pipeline policy that tracks stream progress and runs a raw response hook.

    The policy maintains ``data_stream_total``, ``download_stream_current`` and
    ``upload_stream_current`` in the request/response contexts and invokes the
    ``raw_response_hook`` callback (a plain callable or an ``async def`` function)
    with every response.
    """

    def __init__(self, **kwargs):
        # Default hook, used when a request does not supply its own ``raw_response_hook``.
        self._response_callback = kwargs.get('raw_response_hook')
        super(AsyncStorageResponseHook, self).__init__()

    async def send(self, request: "PipelineRequest") -> "PipelineResponse":
        """Send ``request``, update the progress counters and invoke the response hook.

        :param request: The outgoing pipeline request.
        :return: The pipeline response produced by the next policy.
        """
        # Values could be 0, so compare against None rather than relying on truthiness.
        data_stream_total = request.context.get('data_stream_total')
        if data_stream_total is None:
            data_stream_total = request.context.options.pop('data_stream_total', None)
        download_stream_current = request.context.get('download_stream_current')
        if download_stream_current is None:
            download_stream_current = request.context.options.pop('download_stream_current', None)
        upload_stream_current = request.context.get('upload_stream_current')
        if upload_stream_current is None:
            upload_stream_current = request.context.options.pop('upload_stream_current', None)

        # A hook already stored on the context (written at the end of this method, e.g.
        # when the same request passes through again) wins over per-call options and the
        # policy-level default.
        response_callback = request.context.get('response_callback') or \
            request.context.options.pop('raw_response_hook', self._response_callback)

        response = await self.next.send(request)
        will_retry = is_retry(response, request.context.options.get('mode')) or await is_checksum_retry(response)

        # Auth error could come from Bearer challenge, in which case this request will be made again
        is_auth_error = response.http_response.status_code == 401
        # Only count bytes for responses that will not be re-sent, so progress is not double counted.
        should_update_counts = not (will_retry or is_auth_error)
        if should_update_counts and download_stream_current is not None:
            download_stream_current += int(response.http_response.headers.get('Content-Length', 0))
            if data_stream_total is None:
                content_range = response.http_response.headers.get('Content-Range')
                if content_range:
                    # e.g. "bytes 0-99/1000" -> total size 1000
                    data_stream_total = int(content_range.split(' ', 1)[1].split('/', 1)[1])
                else:
                    data_stream_total = download_stream_current
        elif should_update_counts and upload_stream_current is not None:
            upload_stream_current += int(response.http_request.headers.get('Content-Length', 0))

        for pipeline_obj in [request, response]:
            if hasattr(pipeline_obj, 'context'):
                pipeline_obj.context['data_stream_total'] = data_stream_total
                pipeline_obj.context['download_stream_current'] = download_stream_current
                pipeline_obj.context['upload_stream_current'] = upload_stream_current

        if response_callback:
            # Call the hook first, then await its result if awaitable: an ``async def``
            # hook only returns a coroutine, and the hook function itself is never a
            # coroutine object, so checking the callable (not the result) would skip it.
            callback_result = response_callback(response)
            if inspect.isawaitable(callback_result):
                await callback_result
        request.context['response_callback'] = response_callback
        return response
class AsyncStorageRetryPolicy(StorageRetryPolicy):
    """
    The base class for Exponential and Linear retries containing shared code.

    Implements the async send/retry loop; subclasses supply ``get_backoff_time``.
    """

    async def sleep(self, settings, transport):
        """Await the transport's ``sleep`` for the backoff computed from ``settings``.

        :param settings: The retry settings handed to ``get_backoff_time``.
        :param transport: The pipeline transport whose ``sleep`` coroutine is awaited.
        """
        delay = self.get_backoff_time(settings)
        # A missing, zero or negative backoff means: do not wait at all.
        if not delay or delay < 0:
            return
        await transport.sleep(delay)

    async def send(self, request):
        """Send ``request`` through the rest of the pipeline, retrying as configured.

        A response is retried when ``is_retry`` flags it or its MD5 check fails, and an
        ``AzureError`` (other than ``AzureSigningError``) is retried as well. Retrying
        stops as soon as ``increment`` reports that no retries remain.

        :param request: The pipeline request to send.
        :return: The final pipeline response, annotated with retry history and location mode.
        :raises AzureSigningError: Re-raised immediately, without retrying.
        :raises AzureError: The last error once retries are exhausted.
        """
        settings = self.configure_retries(request)
        while True:
            try:
                response = await self.next.send(request)
                should_retry = is_retry(response, settings['mode']) or await is_checksum_retry(response)
                if not should_retry:
                    break
                if not self.increment(settings, request=request.http_request, response=response.http_response):
                    # Out of retries: return the last (retryable) response as-is.
                    break
                await retry_hook(
                    settings,
                    request=request.http_request,
                    response=response.http_response,
                    error=None)
                await self.sleep(settings, request.context.transport)
            except AzureError as err:
                # Signing failures are re-raised without consuming a retry.
                if isinstance(err, AzureSigningError):
                    raise
                if not self.increment(settings, request=request.http_request, error=err):
                    raise err
                await retry_hook(
                    settings,
                    request=request.http_request,
                    response=None,
                    error=err)
                await self.sleep(settings, request.context.transport)

        if settings['history']:
            response.context['history'] = settings['history']
        response.http_response.location_mode = settings['mode']
        return response
class ExponentialRetry(AsyncStorageRetryPolicy):
    """Exponential retry."""

    initial_backoff: int
    """The initial backoff interval, in seconds, for the first retry."""
    increment_base: int
    """The base, in seconds, to increment the initial_backoff by after the
    first retry."""
    random_jitter_range: int
    """A number in seconds which indicates a range to jitter/randomize for the back-off interval."""

    def __init__(
        self,
        initial_backoff: int = 15,
        increment_base: int = 3,
        retry_total: int = 3,
        retry_to_secondary: bool = False,
        random_jitter_range: int = 3,
        **kwargs: Any
    ) -> None:
        """
        Constructs an Exponential retry object. The initial_backoff is used for
        the first retry. Subsequent retries are retried after initial_backoff +
        increment_base^retry_count seconds. For example, by default the first retry
        occurs after 15 seconds, the second after (15+3^1) = 18 seconds, and the
        third after (15+3^2) = 24 seconds.

        :param int initial_backoff:
            The initial backoff interval, in seconds, for the first retry.
        :param int increment_base:
            The base, in seconds, to increment the initial_backoff by after the
            first retry.
        :param int retry_total:
            The maximum number of retry attempts.
        :param bool retry_to_secondary:
            Whether the request should be retried to secondary, if able. This should
            only be enabled if RA-GRS accounts are used and potentially stale data
            can be handled.
        :param int random_jitter_range:
            A number in seconds which indicates a range to jitter/randomize for the back-off interval.
            For example, a random_jitter_range of 3 results in the back-off interval x varying
            between x-3 and x+3 (the lower bound never goes below 0).
        """
        self.initial_backoff = initial_backoff
        self.increment_base = increment_base
        self.random_jitter_range = random_jitter_range
        super(ExponentialRetry, self).__init__(
            retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs)

    def get_backoff_time(self, settings: Dict[str, Any]) -> float:
        """
        Calculates how long to sleep before retrying.

        :param Dict[str, Any] settings: The configurable values pertaining to the backoff time.
        :return:
            The number of seconds to wait before retrying the request: the exponential
            backoff for the current retry count, jittered uniformly by up to
            random_jitter_range seconds in either direction (but not below 0).
        :rtype: float
        """
        random_generator = random.Random()
        # NOTE(review): send() calls increment() before sleep(), and retry_hook reports
        # count - 1 as the retry number, so count is presumably already >= 1 on the first
        # retry (i.e. initial_backoff + increment_base) -- confirm against the docstring example.
        backoff = self.initial_backoff + (0 if settings['count'] == 0 else pow(self.increment_base, settings['count']))
        random_range_start = backoff - self.random_jitter_range if backoff > self.random_jitter_range else 0
        random_range_end = backoff + self.random_jitter_range
        return random_generator.uniform(random_range_start, random_range_end)
class LinearRetry(AsyncStorageRetryPolicy):
    """Linear retry."""

    backoff: int
    """The backoff interval, in seconds, between retries."""
    random_jitter_range: int
    """A number in seconds which indicates a range to jitter/randomize for the back-off interval."""

    def __init__(
        self, backoff: int = 15,
        retry_total: int = 3,
        retry_to_secondary: bool = False,
        random_jitter_range: int = 3,
        **kwargs: Any
    ) -> None:
        """
        Constructs a Linear retry object.

        :param int backoff:
            The backoff interval, in seconds, between retries.
        :param int retry_total:
            The maximum number of retry attempts.
        :param bool retry_to_secondary:
            Whether the request should be retried to secondary, if able. This should
            only be enabled if RA-GRS accounts are used and potentially stale data
            can be handled.
        :param int random_jitter_range:
            A number in seconds which indicates a range to jitter/randomize for the back-off interval.
            For example, a random_jitter_range of 3 results in the back-off interval x varying
            between x-3 and x+3 (the lower bound never goes below 0).
        """
        self.backoff = backoff
        self.random_jitter_range = random_jitter_range
        super(LinearRetry, self).__init__(
            retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs)

    def get_backoff_time(self, settings: Dict[str, Any]) -> float:
        """
        Calculates how long to sleep before retrying.

        :param Dict[str, Any] settings: The configurable values pertaining to the backoff time
            (not used by the linear strategy).
        :return:
            The number of seconds to wait before retrying the request: ``backoff``
            jittered uniformly by up to random_jitter_range seconds in either direction
            (but not below 0).
        :rtype: float
        """
        random_generator = random.Random()
        # the backoff interval normally does not change, however there is the possibility
        # that it was modified by accessing the property directly after initializing the object
        random_range_start = self.backoff - self.random_jitter_range \
            if self.backoff > self.random_jitter_range else 0
        random_range_end = self.backoff + self.random_jitter_range
        return random_generator.uniform(random_range_start, random_range_end)
class AsyncStorageBearerTokenCredentialPolicy(AsyncBearerTokenCredentialPolicy):
    """Async bearer-token policy that follows Storage ``WWW-Authenticate`` challenges.

    When a challenge can be parsed, the request is re-authorized for the scope derived
    from the challenge's resource id, using the challenge's tenant id.
    """

    def __init__(self, credential: "AsyncTokenCredential", audience: str, **kwargs: Any) -> None:
        # ``audience`` is forwarded positionally to the base policy, alongside the credential.
        super().__init__(credential, audience, **kwargs)

    async def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool:
        """Handle an authentication challenge carried by ``response``.

        :param request: The request that received the challenge; it is re-authorized in place.
        :param response: The response whose ``WWW-Authenticate`` header is parsed.
        :return: ``True`` once the request has been re-authorized, ``False`` when the
            challenge header could not be parsed.
        :rtype: bool
        """
        try:
            parsed = StorageHttpChallenge(
                response.http_response.headers.get("WWW-Authenticate"))
        except ValueError:
            # Unparseable challenge: decline to handle it.
            return False
        await self.authorize_request(
            request,
            parsed.resource_id + DEFAULT_OAUTH_SCOPE,
            tenant_id=parsed.tenant_id)
        return True