# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import base64
import hashlib
import logging
import random
import re
import uuid
from io import SEEK_SET, UnsupportedOperation
from time import time
from typing import Any, Dict, Optional, TYPE_CHECKING
from urllib.parse import (
parse_qsl,
urlencode,
urlparse,
urlunparse,
)
from wsgiref.handlers import format_date_time
from azure.core.exceptions import AzureError, ServiceRequestError, ServiceResponseError
from azure.core.pipeline.policies import (
BearerTokenCredentialPolicy,
HeadersPolicy,
HTTPPolicy,
NetworkTraceLoggingPolicy,
RequestHistory,
SansIOHTTPPolicy
)
from .authentication import AzureSigningError, StorageHttpChallenge
from .constants import DEFAULT_OAUTH_SCOPE
from .models import LocationMode
if TYPE_CHECKING:
from azure.core.credentials import TokenCredential
from azure.core.pipeline.transport import ( # pylint: disable=non-abstract-transport-import
PipelineRequest,
PipelineResponse
)
# Module-level logger; the logging policy defined in this file emits its DEBUG output through it.
_LOGGER = logging.getLogger(__name__)
def encode_base64(data):
    """Return the base64 encoding of *data* as text.

    :param data: Either ``str`` (encoded as UTF-8 first) or a bytes-like value.
    :returns: The base64 representation decoded to ``str``.
    :rtype: str
    """
    raw = data.encode('utf-8') if isinstance(data, str) else data
    return base64.b64encode(raw).decode('utf-8')
def is_exhausted(settings):
    """Report whether any retry budget in *settings* has been used up.

    Counters that are falsy (``None`` or ``0``) are ignored; the budget is
    exhausted as soon as one of the remaining counters has dropped below zero.

    :param settings: Mapping with the 'total', 'connect', 'read' and 'status' counters.
    :returns: True if at least one considered counter is negative, otherwise False.
    :rtype: bool
    """
    counters = (settings['total'], settings['connect'], settings['read'], settings['status'])
    active = [remaining for remaining in counters if remaining]
    return bool(active) and min(active) < 0
def retry_hook(settings, **kwargs):
    """Invoke the user-supplied retry callback, if one is configured.

    The callback receives ``retry_count`` (``settings['count'] - 1``),
    ``location_mode`` (``settings['mode']``) and any extra keyword arguments.

    :param settings: Retry settings mapping holding 'hook', 'count' and 'mode'.
    """
    callback = settings['hook']
    if not callback:
        return
    callback(retry_count=settings['count'] - 1, location_mode=settings['mode'], **kwargs)
def is_retry(response, mode):
    """Decide from the HTTP status code whether a response should be retried.

    :param response: Pipeline response whose ``http_response.status_code`` is inspected.
    :param mode: The location mode the request was sent with.
    :returns: True when the status code is considered retryable.
    :rtype: bool
    """
    status = response.http_response.status_code

    # Anything below 300 is not treated as a retryable failure.
    if status < 300:
        return False

    if status >= 500:
        # 5xx responses indicate a server issue and are retried, with the exception
        # of 501 Not Implemented and 505 Version Not Supported.
        return status not in (501, 505)

    # 3xx / 4xx: the failure was usually expected (for example 409 Conflict or
    # 412 Precondition Failed), so only two cases are retried.
    if status == 404 and mode == LocationMode.SECONDARY:
        # A 404 from the secondary location is retried.
        return True
    # 408 is a timeout and is retried; everything else is final.
    return status == 408
def is_checksum_retry(response):
    """Report whether a response should be retried because its Content-MD5 does not match.

    The check only runs when content validation was requested (``validate_content``
    in the response context) and the response carries a ``content-md5`` header. The
    header is compared against the request's ``content-md5`` header when present,
    otherwise against the MD5 computed from the response body.

    :param response: The pipeline response to inspect.
    :returns: True if the MD5 values differ, otherwise False.
    :rtype: bool
    """
    if not response.context.get('validate_content', False):
        return False
    if not response.http_response.headers.get('content-md5'):
        return False

    request_md5 = response.http_request.headers.get('content-md5', None)
    computed_md5 = request_md5 or encode_base64(
        StorageContentValidation.get_content_md5(response.http_response.body()))
    if response.http_response.headers['content-md5'] != computed_md5:
        return True
    return False
def urljoin(base_url, stub_url):
    """Append *stub_url* as an extra path segment of *base_url*.

    The segment is joined with a single '/' after the existing path; the query
    string and other URL components of *base_url* are kept.

    :param str base_url: The URL to extend.
    :param str stub_url: The path segment to append.
    :returns: The extended URL.
    :rtype: str
    """
    parts = urlparse(base_url)
    extended_path = '/'.join((parts.path, stub_url))
    return urlunparse(parts._replace(path=extended_path))
class QueueMessagePolicy(SansIOHTTPPolicy):
    """Policy that extends the request URL with a queue message id."""

    def on_request(self, request):
        """Append the ``queue_message_id`` request option, when set, to the URL path.

        The option is removed from the request options whether or not it is used.

        :param request: The pipeline request being prepared.
        """
        message_id = request.context.options.pop('queue_message_id', None)
        if not message_id:
            return
        http_request = request.http_request
        http_request.url = urljoin(http_request.url, message_id)
class StorageHeadersPolicy(HeadersPolicy):
    """Headers policy that also stamps each request with a date and a client request id."""

    request_id_header_name = 'x-ms-client-request-id'

    def on_request(self, request: "PipelineRequest") -> None:
        """Run the base policy, then set 'x-ms-date' and 'x-ms-client-request-id'.

        The client request id comes from the ``client_request_id`` request option
        when provided; otherwise a new ``uuid1`` value is generated.

        :param PipelineRequest request: The pipeline request being prepared.
        """
        super().on_request(request)

        headers = request.http_request.headers
        headers['x-ms-date'] = format_date_time(time())

        custom_id = request.context.options.pop('client_request_id', None)
        headers['x-ms-client-request-id'] = custom_id if custom_id else str(uuid.uuid1())

    # The response-side check below is disabled (kept commented out):
    #
    # def on_response(self, request, response):
    #     # raise exception if the echoed client request id from the service is not identical to the one we sent
    #     if self.request_id_header_name in response.http_response.headers:
    #         client_request_id = request.http_request.headers.get(self.request_id_header_name)
    #         if response.http_response.headers[self.request_id_header_name] != client_request_id:
    #             raise AzureError(
    #                 "Echoed client request ID: {} does not match sent client request ID: {}. "
    #                 "Service request ID: {}".format(
    #                     response.http_response.headers[self.request_id_header_name], client_request_id,
    #                     response.http_response.headers['x-ms-request-id']),
    #                 response=response.http_response
    #             )
class StorageHosts(SansIOHTTPPolicy):
    """Policy that records the known hosts and resolves the location mode of each request."""

    def __init__(self, hosts=None, **kwargs):  # pylint: disable=unused-argument
        """
        :param hosts: Mapping of location mode to host (network location).
        """
        self.hosts = hosts
        super().__init__()

    def on_request(self, request: "PipelineRequest") -> None:
        """Publish 'hosts' and 'location_mode' in the request options, redirecting if asked.

        When the ``use_location`` option is given, retrying to the secondary is switched
        off and the request URL is pointed at that location's host.

        :param PipelineRequest request: The pipeline request being prepared.
        :raises ValueError: If ``use_location`` names a location missing from ``hosts``.
        """
        options = request.context.options
        options['hosts'] = self.hosts
        parsed_url = urlparse(request.http_request.url)

        # Work out which location the URL currently points at; PRIMARY when no host matches.
        location_mode = LocationMode.PRIMARY
        for mode_key, host in self.hosts.items():
            if parsed_url.netloc == host:
                location_mode = mode_key

        # An explicitly requested location overrides whatever the URL points at.
        use_location = options.pop('use_location', None)
        if use_location:
            # Pin retries to the requested location.
            options['retry_to_secondary'] = False
            if use_location not in self.hosts:
                raise ValueError(f"Attempting to use undefined host location {use_location}")
            if use_location != location_mode:
                # Rewrite the URL so it targets the requested location's host.
                redirected = parsed_url._replace(netloc=self.hosts[use_location])
                request.http_request.url = redirected.geturl()
                location_mode = use_location

        options['location_mode'] = location_mode
class StorageLoggingPolicy(NetworkTraceLoggingPolicy):
    """A policy that logs HTTP requests and responses to the DEBUG logger.

    Logging is controlled by the ``logging_enable`` constructor argument and can be
    overridden per request with the ``logging_enable`` request option. Bodies are
    only logged when ``logging_body`` is set, at construction or on a request.
    """

    # Matches a Content-Disposition header that announces a file attachment.
    _ATTACHMENT_PATTERN = re.compile(r'attachment; ?filename=["\w.]+', re.IGNORECASE)

    def __init__(self, logging_enable: bool = False, **kwargs) -> None:
        """
        :param bool logging_enable: Whether HTTP logging is enabled by default.
        :keyword bool logging_body: Whether request/response bodies are logged.
        """
        self.logging_body = kwargs.pop("logging_body", False)
        super(StorageLoggingPolicy, self).__init__(logging_enable=logging_enable, **kwargs)

    def on_request(self, request: "PipelineRequest") -> None:
        """Log the outgoing request (URL, method, headers, optionally body) at DEBUG level.

        SAS signatures in the URL and in the 'x-ms-copy-source' header, as well as the
        'Authorization' header, are masked. Failures while logging are themselves
        logged and never propagate.

        :param PipelineRequest request: The pipeline request about to be sent.
        """
        http_request = request.http_request
        options = request.context.options
        # Always consume the per-request option so it is removed from the request
        # options even when body logging is already enabled on the policy.
        body_option = options.pop("logging_body", False)
        # NOTE(review): the flag is stored on the policy, so once a request enables body
        # logging it stays enabled for later requests through this instance - confirm intended.
        self.logging_body = self.logging_body or body_option
        if options.pop("logging_enable", self.enable_http_logger):
            request.context["logging_enable"] = True
            if not _LOGGER.isEnabledFor(logging.DEBUG):
                return

            try:
                log_url = http_request.url
                query_params = http_request.query
                # Mask only the signature value; the 'sig=' key is already part of the URL.
                # An empty value is skipped because str.replace('', ...) would insert the
                # mask between every character.
                if 'sig' in query_params and query_params['sig']:
                    log_url = log_url.replace(query_params['sig'], "*****")
                _LOGGER.debug("Request URL: %r", log_url)
                _LOGGER.debug("Request method: %r", http_request.method)
                _LOGGER.debug("Request headers:")
                for header, value in http_request.headers.items():
                    if header.lower() == 'authorization':
                        value = '*****'
                    elif header.lower() == 'x-ms-copy-source' and 'sig' in value:
                        # take the url apart and scrub away the signed signature
                        scheme, netloc, path, params, query, fragment = urlparse(value)
                        parsed_qs = dict(parse_qsl(query))
                        parsed_qs['sig'] = '*****'

                        # the SAS needs to be put back together
                        value = urlunparse((scheme, netloc, path, params, urlencode(parsed_qs), fragment))

                    _LOGGER.debug(" %r: %r", header, value)
                _LOGGER.debug("Request body:")

                if self.logging_body:
                    _LOGGER.debug(str(http_request.body))
                else:
                    # We don't want to log the binary data of a file upload.
                    _LOGGER.debug("Hidden body, please use logging_body to show body")
            except Exception as err:  # pylint: disable=broad-except
                _LOGGER.debug("Failed to log request: %r", err)

    def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None:
        """Log the response status, headers and (optionally) body at DEBUG level.

        Failures while logging are themselves logged and never propagate.

        :param PipelineRequest request: The pipeline request that was sent.
        :param PipelineResponse response: The pipeline response that was received.
        """
        if response.context.pop("logging_enable", self.enable_http_logger):
            if not _LOGGER.isEnabledFor(logging.DEBUG):
                return

            try:
                _LOGGER.debug("Response status: %r", response.http_response.status_code)
                _LOGGER.debug("Response headers:")
                for res_header, value in response.http_response.headers.items():
                    _LOGGER.debug(" %r: %r", res_header, value)

                # We don't want to log binary data if the response is a file.
                _LOGGER.debug("Response content:")
                header = response.http_response.headers.get('content-disposition')
                resp_content_type = response.http_response.headers.get("content-type", "")

                if header and self._ATTACHMENT_PATTERN.match(header):
                    filename = header.partition('=')[2]
                    _LOGGER.debug("File attachments: %s", filename)
                elif resp_content_type.endswith("octet-stream"):
                    _LOGGER.debug("Body contains binary data.")
                elif resp_content_type.startswith("image"):
                    _LOGGER.debug("Body contains image data.")

                if self.logging_body and resp_content_type.startswith("text"):
                    _LOGGER.debug(response.http_response.text())
                elif self.logging_body:
                    try:
                        _LOGGER.debug(response.http_response.body())
                    except ValueError:
                        _LOGGER.debug("Body is streamable")

            except Exception as err:  # pylint: disable=broad-except
                _LOGGER.debug("Failed to log response: %s", repr(err))
class StorageRequestHook(SansIOHTTPPolicy):
    """Policy that hands each outgoing request to an optional user callback."""

    def __init__(self, **kwargs):
        """
        :keyword raw_request_hook: Default callable invoked with every pipeline request.
        """
        self._request_callback = kwargs.get('raw_request_hook')
        super().__init__()

    def on_request(self, request: "PipelineRequest") -> None:
        """Call the per-request ``raw_request_hook`` option, or the policy default.

        :param PipelineRequest request: The pipeline request being prepared.
        """
        callback = request.context.options.pop('raw_request_hook', self._request_callback)
        if not callback:
            return
        callback(request)
class StorageResponseHook(HTTPPolicy):
    """Policy that tracks transfer progress and invokes an optional response callback."""

    def __init__(self, **kwargs):
        """
        :keyword raw_response_hook: Default callable invoked with every pipeline response.
        """
        self._response_callback = kwargs.get('raw_response_hook')
        super().__init__()

    def send(self, request: "PipelineRequest") -> "PipelineResponse":
        """Send the request, update the stream progress counters and run the callback.

        Progress values are read from the request context first and from the request
        options otherwise. The counters are left unchanged when the response will be
        retried or is a 401.

        :param PipelineRequest request: The pipeline request to send.
        :returns: The pipeline response from the next policy.
        :rtype: PipelineResponse
        """
        context = request.context

        def _context_or_option(key):
            # Values could be 0, so compare against None instead of testing truthiness.
            current = context.get(key)
            if current is None:
                current = context.options.pop(key, None)
            return current

        data_stream_total = _context_or_option('data_stream_total')
        download_stream_current = _context_or_option('download_stream_current')
        upload_stream_current = _context_or_option('upload_stream_current')

        response_callback = context.get('response_callback') or \
            context.options.pop('raw_response_hook', self._response_callback)

        response = self.next.send(request)

        will_retry = is_retry(response, request.context.options.get('mode')) or is_checksum_retry(response)
        # Auth error could come from Bearer challenge, in which case this request will be made again
        is_auth_error = response.http_response.status_code == 401

        if not (will_retry or is_auth_error):
            if download_stream_current is not None:
                download_stream_current += int(response.http_response.headers.get('Content-Length', 0))
                if data_stream_total is None:
                    content_range = response.http_response.headers.get('Content-Range')
                    if content_range:
                        # Take the part after the '/' of the Content-Range value as the total.
                        data_stream_total = int(content_range.split(' ', 1)[1].split('/', 1)[1])
                    else:
                        data_stream_total = download_stream_current
            elif upload_stream_current is not None:
                upload_stream_current += int(response.http_request.headers.get('Content-Length', 0))

        progress = (
            ('data_stream_total', data_stream_total),
            ('download_stream_current', download_stream_current),
            ('upload_stream_current', upload_stream_current),
        )
        for pipeline_obj in (request, response):
            if hasattr(pipeline_obj, 'context'):
                for key, value in progress:
                    pipeline_obj.context[key] = value

        if response_callback:
            response_callback(response)
            request.context['response_callback'] = response_callback
        return response
class StorageContentValidation(SansIOHTTPPolicy):
    """A policy that validates payloads with MD5 checksums.

    When the ``validate_content`` request option is set, a Content-MD5 header is
    computed and attached to non-GET requests, and a 'content-md5' header returned
    by the service is checked on the response.
    """
    header_name = 'Content-MD5'

    def __init__(self, **kwargs: Any) -> None:  # pylint: disable=unused-argument
        super(StorageContentValidation, self).__init__()

    @staticmethod
    def get_content_md5(data):
        """Compute the MD5 digest of *data*.

        :param data: ``None``/empty, a bytes-like object (``bytes``, ``bytearray``,
            ``memoryview``) or a file-like object with ``read``. A stream is read from
            its current position in 4096-byte chunks and then seeked back to where it
            started (position 0 if the position could not be determined).
        :returns: The raw (binary) MD5 digest.
        :rtype: bytes
        :raises ValueError: If *data* is neither bytes-like nor a seekable file-like object.
        """
        # Since HTTP does not differentiate between no content and empty content,
        # we have to perform a None check.
        data = data or b""
        md5 = hashlib.md5()  # nosec
        if isinstance(data, (bytes, bytearray, memoryview)):
            md5.update(data)
        elif hasattr(data, 'read'):
            pos = 0
            try:
                pos = data.tell()
            except Exception:  # pylint: disable=broad-except
                # Best effort: fall back to position 0, but let KeyboardInterrupt and
                # SystemExit propagate instead of swallowing them.
                pass
            for chunk in iter(lambda: data.read(4096), b""):
                md5.update(chunk)
            try:
                data.seek(pos, SEEK_SET)
            except (AttributeError, IOError) as exc:
                raise ValueError("Data should be bytes or a seekable file-like object.") from exc
        else:
            raise ValueError("Data should be bytes or a seekable file-like object.")

        return md5.digest()

    def on_request(self, request: "PipelineRequest") -> None:
        """Attach a Content-MD5 header to non-GET requests when validation is requested.

        The computed value is also stored in the request context under
        'validate_content_md5', and the option itself under 'validate_content'.

        :param PipelineRequest request: The pipeline request being prepared.
        """
        validate_content = request.context.options.pop('validate_content', False)
        if validate_content and request.http_request.method != 'GET':
            computed_md5 = encode_base64(StorageContentValidation.get_content_md5(request.http_request.data))
            request.http_request.headers[self.header_name] = computed_md5
            request.context['validate_content_md5'] = computed_md5
        request.context['validate_content'] = validate_content

    def on_response(self, request: "PipelineRequest", response: "PipelineResponse") -> None:
        """Compare the response 'content-md5' header with the expected MD5.

        The expected value is the one stored by ``on_request`` when available,
        otherwise the MD5 of the response body.

        :param PipelineRequest request: The pipeline request that was sent.
        :param PipelineResponse response: The pipeline response that was received.
        :raises AzureError: If the two MD5 values differ.
        """
        if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'):
            computed_md5 = request.context.get('validate_content_md5') or \
                encode_base64(StorageContentValidation.get_content_md5(response.http_response.body()))
            if response.http_response.headers['content-md5'] != computed_md5:
                raise AzureError((
                    f"MD5 mismatch. Expected value is '{response.http_response.headers['content-md5']}', "
                    f"computed value is '{computed_md5}'."),
                    response=response.http_response
                )
class StorageRetryPolicy(HTTPPolicy):
    """
    The base class for Exponential and Linear retries containing shared code.
    """

    total_retries: int
    """The max number of retries."""
    connect_retries: int
    """The max number of connect retries."""
    read_retries: int
    """The max number of read retries."""
    status_retries: int
    """The max number of status retries."""
    retry_to_secondary: bool
    """Whether the secondary endpoint should be retried."""

    def __init__(self, **kwargs: Any) -> None:
        """
        :keyword int retry_total: The max number of retries (default 10).
        :keyword int retry_connect: The max number of connect retries (default 3).
        :keyword int retry_read: The max number of read retries (default 3).
        :keyword int retry_status: The max number of status retries (default 3).
        :keyword bool retry_to_secondary: Whether the secondary endpoint should be retried.
        """
        self.total_retries = kwargs.pop('retry_total', 10)
        self.connect_retries = kwargs.pop('retry_connect', 3)
        self.read_retries = kwargs.pop('retry_read', 3)
        self.status_retries = kwargs.pop('retry_status', 3)
        self.retry_to_secondary = kwargs.pop('retry_to_secondary', False)
        super().__init__()

    def _set_next_host_location(self, settings: Dict[str, Any], request: "PipelineRequest") -> None:
        """
        A function which sets the next host location on the request, if applicable.

        The location mode is flipped between primary and secondary, and the request
        URL is rewritten, only when every configured host has a value.

        :param Dict[str, Any] settings: The configurable values pertaining to the next host location.
        :param PipelineRequest request: A pipeline request object.
        """
        if settings['hosts'] and all(settings['hosts'].values()):
            url = urlparse(request.url)
            # If there's more than one possible location, retry to the alternative
            if settings['mode'] == LocationMode.PRIMARY:
                settings['mode'] = LocationMode.SECONDARY
            else:
                settings['mode'] = LocationMode.PRIMARY
            updated = url._replace(netloc=settings['hosts'].get(settings['mode']))
            request.url = updated.geturl()

    def configure_retries(self, request: "PipelineRequest") -> Dict[str, Any]:
        """Build the per-request retry settings.

        Request options override the policy defaults and are removed from the request
        options. For a stream body, the current position is recorded so that the body
        can be rewound before a retry.

        :param PipelineRequest request: A pipeline request object.
        :returns: The retry settings for this request.
        :rtype: Dict[str, Any]
        """
        body_position = None
        if hasattr(request.http_request.body, 'read'):
            try:
                body_position = request.http_request.body.tell()
            except (AttributeError, UnsupportedOperation):
                # if body position cannot be obtained, then retries will not work
                pass
        options = request.context.options
        return {
            'total': options.pop("retry_total", self.total_retries),
            'connect': options.pop("retry_connect", self.connect_retries),
            'read': options.pop("retry_read", self.read_retries),
            'status': options.pop("retry_status", self.status_retries),
            'retry_secondary': options.pop("retry_to_secondary", self.retry_to_secondary),
            'mode': options.pop("location_mode", LocationMode.PRIMARY),
            'hosts': options.pop("hosts", None),
            'hook': options.pop("retry_hook", None),
            'body_position': body_position,
            'count': 0,
            'history': []
        }

    def get_backoff_time(self, settings: Dict[str, Any]) -> float:  # pylint: disable=unused-argument
        """ Formula for computing the current backoff.
        Should be calculated by child class.

        :param Dict[str, Any] settings: The configurable values pertaining to the backoff time.
        :returns: The backoff time.
        :rtype: float
        """
        return 0

    def sleep(self, settings, transport):
        """Sleep on the transport for the computed backoff, if it is positive.

        :param Dict[str, Any] settings: The retry settings used to compute the backoff.
        :param transport: The transport whose ``sleep`` method is used.
        """
        backoff = self.get_backoff_time(settings)
        if not backoff or backoff < 0:
            return
        transport.sleep(backoff)

    def increment(
        self, settings: Dict[str, Any],
        request: "PipelineRequest",
        response: Optional["PipelineResponse"] = None,
        error: Optional[AzureError] = None
    ) -> bool:
        """Increment the retry counters.

        :param Dict[str, Any] settings: The configurable values pertaining to the increment operation.
        :param PipelineRequest request: A pipeline request object.
        :param Optional[PipelineResponse] response: A pipeline response object.
        :param Optional[AzureError] error: An error encountered during the request, or
            None if the response was received successfully.
        :returns: True if the request should be retried; False if the retry attempts
            are exhausted or a stream body cannot be rewound.
        :rtype: bool
        """
        settings['total'] -= 1

        if error and isinstance(error, ServiceRequestError):
            # Errors when we're fairly sure that the server did not receive the
            # request, so it should be safe to retry.
            settings['connect'] -= 1
            settings['history'].append(RequestHistory(request, error=error))

        elif error and isinstance(error, ServiceResponseError):
            # Errors that occur after the request has been started, so we should
            # assume that the server began processing it.
            settings['read'] -= 1
            settings['history'].append(RequestHistory(request, error=error))

        else:
            # Incrementing because of a server error like a 500 in
            # status_forcelist and the given method is in the allowlist
            if response:
                settings['status'] -= 1
                settings['history'].append(RequestHistory(request, http_response=response))

        if not is_exhausted(settings):
            if request.method not in ['PUT'] and settings['retry_secondary']:
                self._set_next_host_location(settings, request)

            # rewind the request body if it is a stream
            if request.body and hasattr(request.body, 'read'):
                # no position was saved, then retry would not work
                if settings['body_position'] is None:
                    return False
                try:
                    # attempt to rewind the body to the initial position
                    request.body.seek(settings['body_position'], SEEK_SET)
                except (UnsupportedOperation, ValueError):
                    # if body is not seekable, then retry would not work
                    return False
            settings['count'] += 1
            return True
        return False

    def send(self, request):
        """Send the request, retrying on retryable responses and errors.

        :param PipelineRequest request: A pipeline request object.
        :returns: The last pipeline response; its context carries 'history' when any
            retry was recorded.
        :rtype: PipelineResponse
        :raises AzureError: When an error occurs and no retries remain;
            ``AzureSigningError`` is re-raised immediately without retrying.
        """
        retries_remaining = True
        response = None
        retry_settings = self.configure_retries(request)
        while retries_remaining:
            try:
                response = self.next.send(request)
                if is_retry(response, retry_settings['mode']) or is_checksum_retry(response):
                    retries_remaining = self.increment(
                        retry_settings,
                        request=request.http_request,
                        response=response.http_response)
                    if retries_remaining:
                        retry_hook(
                            retry_settings,
                            request=request.http_request,
                            response=response.http_response,
                            error=None)
                        self.sleep(retry_settings, request.context.transport)
                        continue
                break
            except AzureError as err:
                if isinstance(err, AzureSigningError):
                    raise
                retries_remaining = self.increment(
                    retry_settings, request=request.http_request, error=err)
                if retries_remaining:
                    retry_hook(
                        retry_settings,
                        request=request.http_request,
                        response=None,
                        error=err)
                    self.sleep(retry_settings, request.context.transport)
                    continue
                raise
        if retry_settings['history']:
            response.context['history'] = retry_settings['history']
        response.http_response.location_mode = retry_settings['mode']
        return response
[docs]
class ExponentialRetry(StorageRetryPolicy):
    """Exponential retry."""

    initial_backoff: int
    """The initial backoff interval, in seconds, for the first retry."""
    increment_base: int
    """The base, in seconds, to increment the initial_backoff by after the
    first retry."""
    random_jitter_range: int
    """A number in seconds which indicates a range to jitter/randomize for the back-off interval."""

    def __init__(
        self, initial_backoff: int = 15,
        increment_base: int = 3,
        retry_total: int = 3,
        retry_to_secondary: bool = False,
        random_jitter_range: int = 3,
        **kwargs: Any
    ) -> None:
        """
        Constructs an Exponential retry object. The backoff is initial_backoff while
        the retry count is 0, and initial_backoff + increment_base^retry_count
        seconds otherwise, before jitter is applied.

        :param int initial_backoff:
            The initial backoff interval, in seconds, for the first retry.
        :param int increment_base:
            The base, in seconds, to increment the initial_backoff by after the
            first retry.
        :param int retry_total:
            The maximum number of retry attempts.
        :param bool retry_to_secondary:
            Whether the request should be retried to secondary, if able. This should
            only be enabled if RA-GRS accounts are used and potentially stale data
            can be handled.
        :param int random_jitter_range:
            A number in seconds which indicates a range to jitter/randomize for the back-off interval.
            For example, a random_jitter_range of 3 results in the back-off interval x to vary between x+3 and x-3.
        """
        self.initial_backoff = initial_backoff
        self.increment_base = increment_base
        self.random_jitter_range = random_jitter_range
        super().__init__(
            retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs)

    def get_backoff_time(self, settings: Dict[str, Any]) -> float:
        """
        Calculates how long to sleep before retrying.

        :param Dict[str, Any] settings: The configurable values pertaining to get backoff time.
        :returns:
            A float, drawn uniformly from the jittered interval around the backoff,
            indicating how long to wait before retrying the request.
        :rtype: float
        """
        retry_count = settings['count']
        if retry_count == 0:
            growth = 0
        else:
            growth = pow(self.increment_base, retry_count)
        backoff = self.initial_backoff + growth

        # The lower bound is clamped to 0 when the jitter is at least as large as the backoff.
        if backoff > self.random_jitter_range:
            lower_bound = backoff - self.random_jitter_range
        else:
            lower_bound = 0
        upper_bound = backoff + self.random_jitter_range

        return random.Random().uniform(lower_bound, upper_bound)
[docs]
class LinearRetry(StorageRetryPolicy):
    """Linear retry."""

    backoff: int
    """The backoff interval, in seconds, between retries."""
    random_jitter_range: int
    """A number in seconds which indicates a range to jitter/randomize for the back-off interval."""

    def __init__(
        self, backoff: int = 15,
        retry_total: int = 3,
        retry_to_secondary: bool = False,
        random_jitter_range: int = 3,
        **kwargs: Any
    ) -> None:
        """
        Constructs a Linear retry object.

        :param int backoff:
            The backoff interval, in seconds, between retries.
        :param int retry_total:
            The maximum number of retry attempts.
        :param bool retry_to_secondary:
            Whether the request should be retried to secondary, if able. This should
            only be enabled if RA-GRS accounts are used and potentially stale data
            can be handled.
        :param int random_jitter_range:
            A number in seconds which indicates a range to jitter/randomize for the back-off interval.
            For example, a random_jitter_range of 3 results in the back-off interval x to vary between x+3 and x-3.
        """
        self.backoff = backoff
        self.random_jitter_range = random_jitter_range
        super(LinearRetry, self).__init__(
            retry_total=retry_total, retry_to_secondary=retry_to_secondary, **kwargs)

    def get_backoff_time(self, settings: Dict[str, Any]) -> float:
        """
        Calculates how long to sleep before retrying.

        :param Dict[str, Any] settings: The configurable values pertaining to the backoff time.
        :returns:
            A float, drawn uniformly from the jittered interval around the backoff,
            indicating how long to wait before retrying the request.
        :rtype: float
        """
        random_generator = random.Random()
        # the backoff interval normally does not change, however there is the possibility
        # that it was modified by accessing the property directly after initializing the object
        random_range_start = self.backoff - self.random_jitter_range \
            if self.backoff > self.random_jitter_range else 0
        random_range_end = self.backoff + self.random_jitter_range
        return random_generator.uniform(random_range_start, random_range_end)
class StorageBearerTokenCredentialPolicy(BearerTokenCredentialPolicy):
    """ Custom Bearer token credential policy for following Storage Bearer challenges """

    def __init__(self, credential: "TokenCredential", audience: str, **kwargs: Any) -> None:
        """
        :param TokenCredential credential: The credential passed to the base policy.
        :param str audience: The audience passed to the base policy.
        """
        super().__init__(credential, audience, **kwargs)

    def on_challenge(self, request: "PipelineRequest", response: "PipelineResponse") -> bool:
        """Re-authorize the request from the response's 'WWW-Authenticate' challenge.

        :param PipelineRequest request: The request that received the challenge.
        :param PipelineResponse response: The response carrying the challenge header.
        :returns: False if the challenge could not be parsed (``ValueError``),
            True once the request has been re-authorized.
        :rtype: bool
        """
        try:
            challenge_header = response.http_response.headers.get("WWW-Authenticate")
            challenge = StorageHttpChallenge(challenge_header)
        except ValueError:
            return False

        requested_scope = challenge.resource_id + DEFAULT_OAUTH_SCOPE
        self.authorize_request(request, requested_scope, tenant_id=challenge.tenant_id)
        return True