# passerelle - uniform access to multiple data sources and services
# Copyright (C) 2019 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import codecs
import datetime
import hashlib
import inspect
import json
import re

from django.apps import apps
from django.conf.urls import url
from django.core.cache import cache
from django.core.exceptions import PermissionDenied
from django.contrib.auth import logout as auth_logout
from django.contrib.auth import views as auth_views
from django.db import transaction
from django.db.models import TextField
from django.db.models.functions import Cast
from django.http import HttpResponse, HttpResponseRedirect, Http404
from django.views.decorators.csrf import csrf_exempt
from django.views.generic import (
    RedirectView, View, TemplateView, CreateView, DeleteView, UpdateView,
    DetailView, ListView)
from django.views.generic.detail import SingleObjectMixin
from django.conf import settings
from django.shortcuts import resolve_url
from django.core.urlresolvers import reverse
from django.utils.timezone import make_aware
from django.utils.translation import ugettext_lazy as _
from django.utils.encoding import force_bytes, force_text
from django.forms.models import modelform_factory
from django.forms.widgets import ClearableFileInput
from django.utils.six.moves.urllib.parse import quote

from dateutil import parser as date_parser
from jsonschema import validate, ValidationError

from passerelle.base.models import BaseResource, ResourceLog
from passerelle.compat import json_loads
from passerelle.utils.jsonresponse import APIError
from passerelle.utils.json import unflatten

from .utils import to_json, is_authorized
from .forms import GenericConnectorForm

if 'mellon' not in settings.INSTALLED_APPS:
    def get_idps():
        """Fallback used when mellon is not installed: there are no identity providers."""
        return []
else:
    # mellon is installed: use its list of configured identity providers
    from mellon.utils import get_idps


def get_all_apps():
    """Return the installed model classes that are enabled connectors.

    A model qualifies when it subclasses ``BaseResource`` and its
    ``is_enabled()`` class-level check returns a truthy value.
    """
    def is_enabled_connector(model):
        return issubclass(model, BaseResource) and model.is_enabled()

    return [model for model in apps.get_models() if is_enabled_connector(model)]


class LoginView(auth_views.LoginView):
    """Login view that hands over to mellon when identity providers exist.

    If ``get_idps()`` yields any truthy entry, the user is redirected to the
    ``mellon_login`` URL, forwarding the ``next`` query parameter (URL-quoted)
    when present.  Otherwise the standard Django login view is used.
    """

    def dispatch(self, request, *args, **kwargs):
        if not any(get_idps()):
            return super(LoginView, self).dispatch(request, *args, **kwargs)
        target = resolve_url('mellon_login')
        if 'next' in request.GET:
            target += '?next=' + quote(request.GET.get('next'))
        return HttpResponseRedirect(target)


login = LoginView.as_view()


def logout(request, next_page=None):
    """Log the user out and redirect.

    When identity providers are configured (``get_idps()``), the request is
    redirected to ``mellon_logout`` without logging out locally here.
    Otherwise the Django session is closed and the user is sent to
    ``next_page`` (resolved with ``resolve_url``), or to ``/`` by default.
    """
    if any(get_idps()):
        return HttpResponseRedirect(resolve_url('mellon_logout'))
    auth_logout(request)
    destination = '/' if next_page is None else resolve_url(next_page)
    return HttpResponseRedirect(destination)


# JSONP callback names are copied verbatim into executable JavaScript, so only
# dotted JavaScript identifiers are accepted (e.g. ``jQuery1124_1546`` or
# ``angular.callbacks._0``).  ``\Z`` (not ``$``) so a trailing newline is refused.
JSONP_CALLBACK_RE = re.compile(r'^[A-Za-z_$][\w$.]*\Z')


def menu_json(request):
    """Describe this service as a menu entry, in JSON or JSONP.

    The body is a one-item list holding the translated label, the
    ``passerelle`` slug and the absolute URL of the ``manage-home`` page.
    If a ``jsonpCallback`` (checked first) or ``callback`` query parameter is
    present, the JSON is wrapped in a call to that function and served as
    ``application/javascript``.  A callback that is not a dotted JavaScript
    identifier is rejected with a 400 response instead of being reflected.
    """
    label = _('Web Services')
    json_str = json.dumps([
        {
            'label': force_text(label),
            'slug': 'passerelle',
            'url': request.build_absolute_uri(reverse('manage-home'))
        }
    ])
    content_type = 'application/json'
    for variable in ('jsonpCallback', 'callback'):
        if variable in request.GET:
            callback = request.GET[variable]
            if not JSONP_CALLBACK_RE.match(callback):
                # never echo arbitrary user input as script
                return HttpResponse('invalid JSONP callback', status=400,
                                    content_type='text/plain')
            json_str = '%s(%s);' % (callback, json_str)
            content_type = 'application/javascript'
            break
    response = HttpResponse(content_type=content_type)
    response.write(json_str)
    return response


class HomePageView(RedirectView):
    """Non-permanent redirect to the ``manage-home`` URL pattern."""
    permanent = False
    pattern_name = 'manage-home'


class ManageView(TemplateView):
    """Management home page listing every connector instance."""
    template_name = 'passerelle/manage.html'

    def get_context_data(self, **kwargs):
        """Add ``apps``: instances of all enabled connector models, sorted by
        case-insensitive title."""
        context = super(ManageView, self).get_context_data(**kwargs)
        instances = [
            instance
            for connector_model in get_all_apps()
            for instance in connector_model.objects.all()
        ]
        context['apps'] = sorted(instances, key=lambda instance: instance.title.lower())
        return context


class ManageAddView(TemplateView):
    """Page listing the connector types that can be added."""
    template_name = 'passerelle/manage_add.html'

    def get_context_data(self, **kwargs):
        """Add ``apps``: enabled connector model classes sorted by verbose name."""
        context = super(ManageAddView, self).get_context_data(**kwargs)
        context['apps'] = sorted(
            get_all_apps(),
            key=lambda connector_model: connector_model.get_verbose_name())
        return context


class GenericConnectorMixin(object):
    """Resolve the connector model and form class from the request URL.

    ``exclude_fields`` lists the model fields left out of the generated
    form; subclasses may override it (the create view keeps ``slug``).
    """
    exclude_fields = ('slug', 'users')

    def get_connector(self, **kwargs):
        """Return the requested connector slug (the ``connector`` URL kwarg), or None."""
        return kwargs.get('connector')

    def init_stuff(self, request, *args, **kwargs):
        """Set ``self.model`` and ``self.form_class`` for the requested connector.

        The app config is the first one exposing ``get_connector_model()``
        whose model's connector slug equals the requested slug.  Its
        ``get_form_class()`` is used when defined; otherwise a model form is
        generated from ``GenericConnectorForm``.

        Raises Http404 when no app config matches.
        """
        wanted_slug = self.get_connector(**kwargs)
        app_config = None
        for candidate in apps.get_app_configs():
            if not hasattr(candidate, 'get_connector_model'):
                continue
            if candidate.get_connector_model().get_connector_slug() == wanted_slug:
                app_config = candidate
                break
        if app_config is None:
            raise Http404()

        self.model = app_config.get_connector_model()
        if hasattr(app_config, 'get_form_class'):
            self.form_class = app_config.get_form_class()
            return

        self.form_class = modelform_factory(
            self.model,
            form=GenericConnectorForm,
            exclude=self.exclude_fields)
        # NOTE(review): template-based widget rendering (Django >= 1.11) may no
        # longer read ``template_with_initial`` -- confirm this still has effect.
        initial_template = (
            '%(initial_text)s: %(initial)s '
            '%(clear_template)s<br />%(input_text)s: %(input)s')
        file_fields = [
            form_field for form_field in self.form_class.base_fields.values()
            if isinstance(form_field.widget, ClearableFileInput)
        ]
        for form_field in file_fields:
            form_field.widget.template_with_initial = initial_template

    def dispatch(self, request, *args, **kwargs):
        """Resolve model and form class before the regular view dispatch."""
        self.init_stuff(request, *args, **kwargs)
        parent = super(GenericConnectorMixin, self)
        return parent.dispatch(request, *args, **kwargs)


class GenericConnectorView(GenericConnectorMixin, DetailView):
    """Detail page of a single connector instance."""

    def get_context_data(self, slug=None, **kwargs):
        """Add ``has_check_status``: False when the connector's ``check_status``
        carries a ``not_implemented`` marker attribute."""
        context = super(GenericConnectorView, self).get_context_data(**kwargs)
        context['has_check_status'] = not hasattr(
            context['object'].check_status, 'not_implemented')
        return context

    def get_template_names(self):
        """Return the default detail templates, then the model's
        ``manager_view_template_name`` (if set), then the generic service page."""
        # Start the super() lookup at this class: passing DetailView would
        # silently skip any get_template_names defined between the two in the MRO.
        template_names = list(super(GenericConnectorView, self).get_template_names())
        if self.model.manager_view_template_name:
            template_names.append(self.model.manager_view_template_name)
        template_names.append('passerelle/manage/service_view.html')
        return template_names


class GenericCreateConnectorView(GenericConnectorMixin, CreateView):
    """Creation form for a new connector instance (the slug is editable)."""
    template_name = 'passerelle/manage/service_form.html'
    exclude_fields = ('users',)  # slug not excluded

    def init_stuff(self, request, *args, **kwargs):
        super(GenericCreateConnectorView, self).init_stuff(request, *args, **kwargs)
        # tell JS to prepopulate 'slug' field using the 'title' field
        title_widget = self.form_class.base_fields['title'].widget
        title_widget.attrs['data-slug-sync'] = 'slug'

    def form_valid(self, form):
        """Save the connector in a transaction, then call ``availability()``
        on it outside that transaction."""
        with transaction.atomic():
            response = super(GenericCreateConnectorView, self).form_valid(form)
        self.object.availability()
        return response


class GenericEditConnectorView(GenericConnectorMixin, UpdateView):
    """Edition form for an existing connector instance."""
    template_name = 'passerelle/manage/service_form.html'

    def form_valid(self, form):
        """Save inside a transaction, then call ``availability()`` once committed."""
        parent = super(GenericEditConnectorView, self)
        with transaction.atomic():
            response = parent.form_valid(form)
        self.object.availability()
        return response


class GenericDeleteConnectorView(GenericConnectorMixin, DeleteView):
    """Deletion confirmation for a connector; redirects to ``manage-home`` afterwards."""
    template_name = 'passerelle/manage/service_confirm_delete.html'

    def get_success_url(self):
        home_url = reverse('manage-home')
        return home_url


class GenericViewLogsConnectorView(GenericConnectorMixin, ListView):
    """Paginated, searchable list of the log entries of one connector instance.

    The optional ``q`` query parameter is first parsed as a date/time (day
    first).  If it parses, entries are limited to the half-open window
    [date, date + span), where span is one day (time is midnight), one
    minute (seconds are zero) or one second.  Otherwise ``q`` is searched
    case-insensitively in the text form of the ``extra`` column.
    """
    template_name = 'passerelle/manage/service_logs.html'
    paginate_by = 25

    def get_context_data(self, **kwargs):
        context = super(GenericViewLogsConnectorView, self).get_context_data(**kwargs)
        context['object'] = self.get_object()
        context['query'] = self.request.GET.get('q') or ''
        return context

    def get_object(self):
        """Return the connector instance for the ``slug`` URL kwarg.

        Raises Http404 (rather than an unhandled DoesNotExist) for an unknown slug.
        """
        try:
            return self.model.objects.get(slug=self.kwargs['slug'])
        except self.model.DoesNotExist:
            raise Http404()

    def get_queryset(self):
        qs = ResourceLog.objects.filter(
            appname=self.kwargs['connector'],
            slug=self.kwargs['slug']).order_by('-timestamp')
        query = self.request.GET.get('q')
        if not query:
            return qs
        try:
            date = date_parser.parse(query, dayfirst=True)
        except Exception:
            # not a date: full-text search in the serialized ``extra`` data
            return qs.annotate(
                text_extra=Cast('extra', TextField())
            ).filter(text_extra__icontains=query)
        return self._filter_by_date(qs, date)

    @staticmethod
    def _date_window(date):
        """Return the span matched by ``date``: a day, a minute or a second."""
        if date.hour == 0 and date.minute == 0 and date.second == 0:
            # just a date: display all events for that date
            return datetime.timedelta(days=1)
        if date.second == 0:
            # without seconds: display all events in this minute
            return datetime.timedelta(seconds=60)
        # display all events in the same second
        return datetime.timedelta(seconds=1)

    def _filter_by_date(self, qs, date):
        """Restrict ``qs`` to timestamps in [date, date + window)."""
        # make_aware() raises ValueError on an already-aware datetime, which
        # happens when the query itself carries a timezone offset
        if date.utcoffset() is None:
            date = make_aware(date)
        # half-open interval: the first instant of the next window is excluded
        return qs.filter(timestamp__gte=date,
                         timestamp__lt=date + self._date_window(date))


class GenericLogView(GenericConnectorMixin, DetailView):
    """Detail page of one log line belonging to a connector instance."""
    template_name = 'passerelle/manage/log.html'

    def get_context_data(self, **kwargs):
        """Add ``logline``, looked up by ``log_pk``, connector and slug.

        Raises Http404 when no such log line exists for this connector.
        """
        context = super(GenericLogView, self).get_context_data(**kwargs)
        lookup = {
            'pk': self.kwargs['log_pk'],
            'appname': self.kwargs['connector'],
            'slug': self.kwargs['slug'],
        }
        try:
            logline = ResourceLog.objects.get(**lookup)
        except ResourceLog.DoesNotExist:
            raise Http404()
        context['logline'] = logline
        return context


class WrongParameter(Exception):
    """Raised when request parameters do not fit an endpoint's signature.

    ``missing`` holds required argument names that were not supplied and
    ``extra`` holds supplied names the endpoint does not accept.  Reported
    as HTTP 400 and not logged as an error.
    """
    http_status = 400
    log_error = False

    def __init__(self, missing, extra):
        self.missing = missing
        self.extra = extra

    def __str__(self):
        labelled = (('missing', self.missing), ('extra', self.extra))
        parts = [
            '%s parameters: %s.' % (label, ', '.join(repr(name) for name in names))
            for label, names in labelled
            if names
        ]
        return ' '.join(parts)


class InvalidParameterValue(Exception):
    """Raised when a typed endpoint parameter cannot be converted (HTTP 400, not logged)."""
    http_status = 400
    log_error = False

    def __init__(self, parameter_name):
        self.parameter_name = parameter_name

    def __str__(self):
        message_template = 'invalid value for parameter "%s"'
        return message_template % self.parameter_name

# Query-string keys never forwarded to endpoints as arguments: authentication /
# signature keys, JSONP callback names and a few other control parameters.
IGNORED_PARAMS = (
    'apikey',
    'signature',
    'nonce',
    'algo',
    'timestamp',
    'orig',
    'jsonpCallback',
    'callback',
    '_',
    'raise',
    'debug',
    'decode',
    'format',
)


class GenericEndpointView(GenericConnectorMixin, SingleObjectMixin, View):
    """Run a connector endpoint (a connector method carrying ``endpoint_info``).

    :meth:`dispatch` selects the endpoint from the ``endpoint`` URL kwarg
    (and, for endpoints declaring a ``pattern``, from the ``rest`` kwarg).
    :meth:`perform` then checks the HTTP method and permissions, builds the
    call arguments with :meth:`get_params`, logs the request and calls the
    endpoint, caching GET results when ``cache_duration`` is set.
    """

    def get_params(self, request, parameters=None, *args, **kwargs):
        """Return the keyword arguments to pass to the endpoint.

        Values come from the query string (minus ``IGNORED_PARAMS``), then
        from ``kwargs['other_params']`` (URL pattern groups), which fill keys
        that are missing or empty in the query string.  Parameters declared
        as bool/int/float are converted.  For POST endpoints declaring an
        ``application/json`` request body schema, the decoded and validated
        body is added as ``post_data``.  ``parameters`` is accepted but unused.

        Raises InvalidParameterValue when a typed value cannot be converted,
        and APIError (HTTP 400) when the body is not JSON or fails validation.
        """
        d = {}
        for key in request.GET:
            # ignore authentication keys and JSONP params
            if key in IGNORED_PARAMS:
                continue
            d[key] = request.GET[key]
        other_params = kwargs.get('other_params', {})
        for key in other_params:
            if other_params[key] is None:
                continue
            if not d.get(key):
                d[key] = other_params[key]
        for parameter_info in self.endpoint.endpoint_info.get_params():
            # check and convert parameter values
            parameter = parameter_info['name']
            if parameter not in d:
                continue
            if parameter_info.get('type') in ('bool', 'boolean'):
                if d[parameter].lower() in ('true', 'on'):
                    d[parameter] = True
                elif d[parameter].lower() in ('false', 'off'):
                    d[parameter] = False
                else:
                    raise InvalidParameterValue(parameter)
            elif parameter_info.get('type') in ('int', 'integer'):
                try:
                    d[parameter] = int(d[parameter])
                except ValueError:
                    raise InvalidParameterValue(parameter)
            elif parameter_info.get('type') == 'float':
                # accept a decimal comma as well as a decimal point
                d[parameter] = d[parameter].replace(',', '.')
                try:
                    d[parameter] = float(d[parameter])
                except ValueError:
                    raise InvalidParameterValue(parameter)

        if request.method == 'POST' and self.endpoint.endpoint_info.post:
            request_body = self.endpoint.endpoint_info.post.get('request_body', {})
            if 'application/json' in request_body.get('schema', {}):
                json_schema = request_body['schema']['application/json']
                # the unflatten / merge_extra / pre_process options are only
                # read from mapping schemas
                is_mapping = hasattr(json_schema, 'items')
                must_unflatten = is_mapping and json_schema.get('unflatten', False)
                merge_extra = is_mapping and json_schema.get('merge_extra', False)
                pre_process = is_mapping and json_schema.get('pre_process')
                try:
                    data = json_loads(request.body)
                except ValueError as e:
                    raise APIError("could not decode body to json: %s" % e, http_status=400)
                if must_unflatten:
                    data = unflatten(data)
                if merge_extra and hasattr(data, 'items'):
                    data.update(data.pop('extra', {}))
                # truthiness, not ``is not None``: pre_process is False for
                # non-mapping schemas and must not be called then
                if pre_process:
                    pre_process(self.endpoint.__self__, data)
                try:
                    validate(data, json_schema)
                except ValidationError as e:
                    error_msg = e.message
                    if e.path:
                        error_msg = '%s: %s' % ('/'.join(map(str, e.path)), error_msg)
                    raise APIError(error_msg, http_status=400)
                d['post_data'] = data

        return d

    @csrf_exempt
    def dispatch(self, request, *args, **kwargs):
        """Resolve the connector instance and endpoint method, or raise Http404."""
        self.init_stuff(request, *args, **kwargs)
        self.connector = self.get_object()
        self.endpoint = None
        for _name, method in inspect.getmembers(self.connector, inspect.ismethod):
            if not hasattr(method, 'endpoint_info'):
                continue
            if not method.endpoint_info.name == kwargs.get('endpoint'):
                continue
            if method.endpoint_info.pattern:
                # an endpoint with a pattern is only selected when ``rest``
                # matches it; such a match stops the search
                pattern = url(method.endpoint_info.pattern, method)
                match = pattern.resolve(kwargs.get('rest') or '')
                if match:
                    self.endpoint = method
                    break
            else:
                self.endpoint = method
        if not self.endpoint:
            raise Http404()
        if kwargs.get('endpoint') == 'up' and hasattr(self.connector.check_status, 'not_implemented'):
            # hide automatic up endpoint if check_status method is not implemented
            raise Http404()
        return super(GenericEndpointView, self).dispatch(request, *args, **kwargs)

    def _allowed_methods(self):
        return [x.upper() for x in self.endpoint.endpoint_info.methods]

    def check_perms(self, request):
        """Return True when the endpoint needs no permission or the request is authorized."""
        perm = self.endpoint.endpoint_info.perm
        if not perm:
            return True
        return is_authorized(request, self.connector, perm)

    def _check_call_args(self, request, argspec, parameters, params):
        """Raise WrongParameter when ``params`` cannot be bound to the endpoint."""
        try:
            inspect.getcallargs(self.endpoint, request, **params)
        except TypeError:
            # prevent errors if using name of an ignored parameter in an endpoint argspec
            ignored = set(parameters) & set(IGNORED_PARAMS)
            assert not ignored, 'endpoint %s has ignored parameter %s' % (request.path, ignored)
            # trailing arguments with a default value are optional; defaults may
            # also cover ``request`` itself, hence the clamp at zero
            n_required = max(len(parameters) - len(argspec.defaults or ()), 0)
            missing = [arg for arg in parameters[:n_required] if arg not in params]
            extra = [key for key in params if key not in argspec.args]
            raise WrongParameter(missing, extra)

    def _log_request(self, request, connector_name, endpoint_name):
        """Log the call with its (possibly truncated) body decoded as UTF-8."""
        full_path = request.get_full_path()
        body = request.body
        payload = body[:self.connector.logging_parameters.requests_max_size]
        truncated = len(payload) < len(body)
        try:
            # when truncated, the cut may fall inside a multibyte character: a
            # non-final incremental decode drops that incomplete tail instead of
            # labelling valid text as binary
            decoder = codecs.getincrementaldecoder('utf-8')()
            payload = decoder.decode(payload, final=not truncated)
        except UnicodeDecodeError:
            payload = '<BINARY PAYLOAD>'
        self.connector.logger.info('endpoint %s %s (%r) ' %
                                   (request.method, full_path, payload),
                                   extra={
                                       'request': request,
                                       'connector': connector_name,
                                       'connector_endpoint': endpoint_name,
                                       'connector_endpoint_method': request.method,
                                       'connector_endpoint_url': full_path,
                                       'connector_payload': payload
                                   })

    def perform(self, request, *args, **kwargs):
        """Check method and permissions, bind parameters, log and call the endpoint.

        GET results are cached for ``cache_duration`` seconds when set.
        Raises PermissionDenied, WrongParameter or the errors of get_params.
        """
        if request.method.lower() not in self.endpoint.endpoint_info.methods:
            return self.http_method_not_allowed(request, *args, **kwargs)
        if not self.check_perms(request):
            raise PermissionDenied()
        # getargspec is deprecated on Python 3 and rejects annotated functions;
        # getfullargspec exposes the same ``args``/``defaults`` attributes
        getargspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
        argspec = getargspec(self.endpoint)
        # skip ``self`` and ``request``
        parameters = argspec.args[2:]
        # computed once and reused below: building them again would decode and
        # validate the body a second time and re-run pre_process hooks
        params = self.get_params(request, parameters=parameters, *args, **kwargs)
        self._check_call_args(request, argspec, parameters, params)

        # auto log request's inputs
        self._log_request(request, kwargs['connector'], kwargs['endpoint'])

        cache_duration = self.endpoint.endpoint_info.cache_duration
        use_cache = request.method == 'GET' and cache_duration
        if use_cache:
            # self.connector already holds get_object(): no extra query needed
            cache_key = hashlib.md5(
                force_bytes(repr(self.connector.slug) + repr(self.endpoint) + repr(params))
            ).hexdigest()
            result = cache.get(cache_key)
            if result is not None:
                return result

        result = self.endpoint(request, **params)
        if use_cache:
            cache.set(cache_key, result, cache_duration)
        return result

    def get(self, request, *args, **kwargs):
        """Check ``rest`` against the endpoint pattern, then run perform() through to_json."""
        if self.endpoint.endpoint_info.pattern:
            pattern = url(self.endpoint.endpoint_info.pattern, self.endpoint)
            match = pattern.resolve(kwargs.get('rest') or '')
            if not match:
                raise Http404()
            # named groups captured by the pattern become candidate arguments
            kwargs['other_params'] = match.kwargs
        elif kwargs.get('rest'):
            # extra path components are only allowed for endpoints with a pattern
            raise Http404()
        return to_json(logger=self.connector.logger)(self.perform)(request, *args, **kwargs)

    def post(self, request, *args, **kwargs):
        return self.get(request, *args, **kwargs)

    def put(self, request, *args, **kwargs):
        # without this handler, endpoints declaring 'put' were unreachable (405);
        # perform() still rejects PUT on endpoints that do not declare it
        return self.get(request, *args, **kwargs)

    def patch(self, request, *args, **kwargs):
        return self.get(request, *args, **kwargs)

    def delete(self, request, *args, **kwargs):
        return self.get(request, *args, **kwargs)


class GenericExportConnectorView(GenericConnectorMixin, DetailView):
    """Download a connector's ``export_json()`` wrapped in a ``resources`` list."""

    def get(self, request, *args, **kwargs):
        response = HttpResponse(content_type='application/json')
        exported = {'resources': [self.get_object().export_json()]}
        json.dump(exported, response, indent=2)
        return response
