# passerelle - uniform access to multiple data sources and services
# Copyright (C) 2018 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import base64
import hashlib
import hmac
import time
from urllib import parse as urlparse
from uuid import uuid4

from django.utils.encoding import force_bytes, force_str
from requests.auth import AuthBase


class HawkAuth(AuthBase):
    """Requests authentication handler that signs requests with a Hawk header.

    The handler builds an ``Authorization: Hawk ...`` header. The header holds
    a payload hash and a MAC computed over the normalized ``hawk.1.header``
    request string.

    :param id: Hawk credentials identifier (text, stored UTF-8 encoded).
    :param key: shared secret used to compute the MAC (text, stored UTF-8
        encoded).
    :param algorithm: hashlib algorithm name. It is used both for the payload
        hash and for the HMAC (default ``'sha256'``).
    :param ext: optional application-specific string. It is always part of
        the MAC input and is added to the header when non-empty.
    """

    # Port implied by the scheme when the URL does not carry an explicit one.
    _DEFAULT_PORTS = {'http': '80', 'https': '443'}

    def __init__(self, id, key, algorithm='sha256', ext=''):
        self.id = id.encode('utf-8')
        self.key = key.encode('utf-8')
        self.algorithm = algorithm
        # NOTE(review): timestamp and nonce are fixed when the instance is
        # created, so every request signed by the same instance reuses them.
        # Hawk servers usually reject replayed nonces and stale timestamps.
        # Confirm that instances are not shared across requests.
        self.timestamp = str(int(time.time()))
        self.nonce = uuid4().hex
        self.ext = ext

    @staticmethod
    def _normalize_content_type(content_type):
        """Return the bare, lower-cased media type of *content_type*.

        Hawk hashes only the media type, without parameters such as
        ``charset``. This matches the reference implementation.
        """
        if not content_type:
            return ''
        return content_type.split(';')[0].strip().lower()

    def get_payload_hash(self, req):
        """Return the base64-encoded Hawk payload hash of *req*.

        The hash covers the normalized Content-Type and the raw request body.
        A missing body is hashed as an empty string.
        """
        p_hash = hashlib.new(self.algorithm)
        p_hash.update(force_bytes('hawk.1.payload\n'))
        content_type = self._normalize_content_type(req.headers.get('Content-Type', ''))
        p_hash.update(force_bytes(content_type + '\n'))
        p_hash.update(force_bytes(req.body or ''))
        p_hash.update(force_bytes('\n'))
        return force_str(base64.b64encode(p_hash.digest()))

    def _get_port(self, url_parts):
        """Return the port of the parsed URL as a string.

        An explicit port in the URL wins. Otherwise the scheme's default
        port is used.

        :raises ValueError: if there is no explicit port and the scheme has
            no known default.
        """
        if url_parts.port is not None:
            return str(url_parts.port)
        try:
            return self._DEFAULT_PORTS[url_parts.scheme]
        except KeyError:
            raise ValueError('cannot determine port for URL scheme %r' % url_parts.scheme) from None

    def get_authorization_header(self, req):
        """Build the value of the ``Authorization`` header for *req*.

        :returns: a ``Hawk id=..., ts=..., nonce=..., hash=..., mac=...``
            string, with ``ext=...`` appended when ``self.ext`` is non-empty.
        :raises ValueError: if the request URL has no explicit port and an
            unknown scheme.
        """
        url_parts = urlparse.urlparse(req.url)
        # The resource is the path plus the query string, when there is one.
        uri = url_parts.path
        if url_parts.query:
            uri += '?' + url_parts.query
        port = self._get_port(url_parts)
        payload_hash = self.get_payload_hash(req)
        data = [
            'hawk.1.header',
            self.timestamp,
            self.nonce,
            req.method.upper(),
            uri,
            url_parts.hostname,
            port,
            payload_hash,
            self.ext,
            '',  # the normalized string must end with a trailing newline
        ]
        # Select the HMAC digest by name, like get_payload_hash() does.
        result = hmac.new(force_bytes(self.key), force_bytes('\n'.join(data)), self.algorithm)
        mac = force_str(base64.b64encode(result.digest()))
        authorization = 'Hawk id="%s", ts="%s", nonce="%s", hash="%s", mac="%s"' % (
            force_str(self.id),
            self.timestamp,
            self.nonce,
            payload_hash,
            mac,
        )
        if self.ext:
            authorization += ', ext="%s"' % self.ext
        return authorization

    def __call__(self, r):
        """Attach the Hawk ``Authorization`` header to request *r* and return it."""
        r.headers['Authorization'] = self.get_authorization_header(r)
        return r


class HttpBearerAuth(AuthBase):
    """Requests authentication handler that sends a fixed bearer token.

    Every request it is applied to gets an ``Authorization: Bearer <token>``
    header. Two handlers compare equal when their ``token`` attributes are
    equal.

    :param token: the bearer token, as text.
    """

    def __init__(self, token):
        self.token = token

    def __eq__(self, other):
        # Objects without a ``token`` attribute are compared against None.
        other_token = getattr(other, 'token', None)
        return self.token == other_token

    def __ne__(self, other):
        return not (self == other)

    def __call__(self, r):
        header_value = 'Bearer ' + self.token
        r.headers['Authorization'] = header_value
        return r
