# Copyright (c) 2013 Yubico AB
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or
# without modification, are permitted provided that the following
# conditions are met:
#
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above
# copyright notice, this list of conditions and the following
# disclaimer in the documentation and/or other materials provided
# with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
# FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
# COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
# INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
# LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
# ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
"""Various utility functions.
This module contains various functions used throughout the rest of the project.
"""
from __future__ import annotations
from base64 import urlsafe_b64decode, urlsafe_b64encode
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hmac, hashes
from io import BytesIO
from dataclasses import fields, Field
from abc import abstractmethod
from typing import (
Union,
Optional,
Sequence,
Mapping,
Dict,
Any,
TypeVar,
Hashable,
get_type_hints,
overload,
Type,
)
import struct
import warnings
__all__ = [
"websafe_encode",
"websafe_decode",
"sha256",
"hmac_sha256",
"bytes2int",
"int2bytes",
]
LOG_LEVEL_TRAFFIC = 5
[docs]
def sha256(data: bytes) -> bytes:
"""Produces a SHA256 hash of the input.
:param data: The input data to hash.
:return: The resulting hash.
"""
h = hashes.Hash(hashes.SHA256(), default_backend())
h.update(data)
return h.finalize()
[docs]
def hmac_sha256(key: bytes, data: bytes) -> bytes:
"""Performs an HMAC-SHA256 operation on the given data, using the given key.
:param key: The key to use.
:param data: The input data to hash.
:return: The resulting hash.
"""
h = hmac.HMAC(key, hashes.SHA256(), default_backend())
h.update(data)
return h.finalize()
[docs]
def bytes2int(value: bytes) -> int:
"""Parses an arbitrarily sized integer from a byte string.
:param value: A byte string encoding a big endian unsigned integer.
:return: The parsed int.
"""
return int.from_bytes(value, "big")
[docs]
def int2bytes(value: int, minlen: int = -1) -> bytes:
"""Encodes an int as a byte string.
:param value: The integer value to encode.
:param minlen: An optional minimum length for the resulting byte string.
:return: The value encoded as a big endian byte string.
"""
ba = []
while value > 0xFF:
ba.append(0xFF & value)
value >>= 8
ba.append(value)
ba.extend([0] * (minlen - len(ba)))
return bytes(reversed(ba))
[docs]
def websafe_decode(data: Union[str, bytes]) -> bytes:
"""Decodes a websafe-base64 encoded string.
See: "Base 64 Encoding with URL and Filename Safe Alphabet" from Section 5
in RFC4648 without padding.
:param data: The input to decode.
:return: The decoded bytes.
"""
if isinstance(data, str):
data = data.encode("ascii")
else:
warnings.warn(
"Calling websafe_decode on a byte value is deprecated, "
"and will no longer be allowed starting in python-fido2 2.0",
DeprecationWarning,
)
data += b"=" * (-len(data) % 4)
return urlsafe_b64decode(data)
[docs]
def websafe_encode(data: bytes) -> str:
"""Encodes a byte string into websafe-base64 encoding.
:param data: The input to encode.
:return: The encoded string.
"""
return urlsafe_b64encode(data).replace(b"=", b"").decode("ascii")
class ByteBuffer(BytesIO):
"""BytesIO-like object with the ability to unpack values."""
def unpack(self, fmt: str):
"""Reads and unpacks a value from the buffer.
:param fmt: A struct format string yielding a single value.
:return: The unpacked value.
"""
s = struct.Struct(fmt)
return s.unpack(self.read(s.size))[0]
def read(self, size: Optional[int] = -1) -> bytes:
"""Like BytesIO.read(), but checks the number of bytes read and raises an error
if fewer bytes were read than expected.
"""
data = super().read(size)
if size is not None and size > 0 and len(data) != size:
raise ValueError(
"Not enough data to read (need: %d, had: %d)." % (size, len(data))
)
return data
_T = TypeVar("_T", bound=Hashable)
_S = TypeVar("_S", bound="_DataClassMapping")
class _DataClassMapping(Mapping[_T, Any]):
"""A data class with members also accessible as a Mapping."""
# TODO: This requires Python 3.9, and fixes the type errors we now ignore
# __dataclass_fields__: ClassVar[Dict[str, Field[Any]]]
def __post_init__(self):
hints = get_type_hints(type(self))
self._field_keys: Dict[_T, Field[Any]]
object.__setattr__(self, "_field_keys", {})
for f in fields(self): # type: ignore
self._field_keys[self._get_field_key(f)] = f
value = getattr(self, f.name)
if value is not None:
try:
value = self._parse_value(hints[f.name], value)
object.__setattr__(self, f.name, value)
except (TypeError, KeyError, ValueError):
raise ValueError(
f"Error parsing field {f.name} for {self.__class__.__name__}"
)
@classmethod
@abstractmethod
def _get_field_key(cls, field: Field) -> _T:
raise NotImplementedError()
def __iter__(self):
return (
k for k, f in self._field_keys.items() if getattr(self, f.name) is not None
)
def __len__(self):
return len(list(iter(self)))
def __getitem__(self, key):
f = self._field_keys[key]
value = getattr(self, f.name)
if value is None:
raise KeyError(key)
serialize = f.metadata.get("serialize")
if serialize:
return serialize(value)
if isinstance(value, Mapping) and not isinstance(value, dict):
return dict(value)
if isinstance(value, Sequence) and all(isinstance(v, Mapping) for v in value):
return [v if isinstance(v, dict) else dict(v) for v in value]
return value
@classmethod
def _parse_value(cls, t, value):
if Optional[t] == t: # Optional, get the type
t = t.__args__[0]
# Check if type is already correct
try:
if t is Any or isinstance(value, t):
return value
except TypeError:
pass
# Handle list of values
if issubclass(getattr(t, "__origin__", object), Sequence):
t = getattr(t, "__args__")[0]
return [cls._parse_value(t, v) for v in value]
# Handle Mappings
elif issubclass(getattr(t, "__origin__", object), Mapping) and isinstance(
value, Mapping
):
t_k, t_v = getattr(t, "__args__")
return {
cls._parse_value(t_k, k): cls._parse_value(t_v, v)
for k, v in value.items()
}
# Check if type has from_dict
from_dict = getattr(t, "from_dict", None)
if from_dict:
return from_dict(value)
# Convert to enum values, other wrappers
wrap = getattr(t, "__call__", None)
if wrap:
return wrap(value)
raise ValueError(f"Unparseable value of type {type(value)} for {t}")
@overload
@classmethod
def from_dict(cls: Type[_S], data: None) -> None: ...
@overload
@classmethod
def from_dict(cls: Type[_S], data: Mapping[_T, Any]) -> _S: ...
@classmethod
def from_dict(cls, data):
if data is None:
return None
if isinstance(data, cls):
return data
if not isinstance(data, Mapping):
raise TypeError(
f"{cls.__name__}.from_dict called with non-Mapping data of type"
f"{type(data)}"
)
kwargs = {}
hints = get_type_hints(cls)
for f in fields(cls): # type: ignore
key = cls._get_field_key(f)
value = data.get(key)
if value is None:
continue
deserialize = f.metadata.get("deserialize")
if deserialize:
value = deserialize(value)
else:
t = hints[f.name]
value = cls._parse_value(t, value)
kwargs[f.name] = value
return cls(**kwargs)
class _JsonDataObject(_DataClassMapping[str]):
"""A data class with members also accessible as a JSON-serializable Mapping."""
@classmethod
def _get_field_key(cls, field: Field) -> str:
name = field.metadata.get("name")
if name:
return name
parts = field.name.split("_")
return parts[0] + "".join(p.title() for p in parts[1:])
def __getitem__(self, key):
value = super().__getitem__(key)
if isinstance(value, bytes):
return websafe_encode(value)
return value
@classmethod
def _parse_value(cls, t, value):
if Optional[t] == t: # Optional, get the type
t2 = t.__args__[0]
else:
t2 = t
# bytes are encoded as websafe_b64 strings
if isinstance(t2, type) and issubclass(t2, bytes) and isinstance(value, str):
return websafe_decode(value)
return super()._parse_value(t, value)