# authentic2-wallonie-connect - Authentic2 plugin for the Wallonie Connect usecase
# Copyright (C) 2019 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import functools
import json

from django.db import transaction
from django.utils import six

from authentic2_idp_oidc.models import OIDCClient, OIDCClaim, generate_uuid
from authentic2.a2_rbac.models import Role, OrganizationalUnit
from authentic2.custom_user.models import User


from django.core.management.base import BaseCommand


class DryRun(Exception):
    """Signal raised at the end of a dry run so that ``dryrun`` rolls back."""


# OIDC claims given to a service whose JSON entry has no "claims" key:
# each maps a claim name to a user attribute and the scopes exposing it.
DEFAULT_CLAIMS = [
    dict(name="given_name", value="first_name", scopes=["profile"]),
    dict(name="family_name", value="last_name", scopes=["profile"]),
    dict(name="email", value="email", scopes=["email"]),
]


def dryrun(func):
    """Run *func* inside a database transaction that ``DryRun`` rolls back.

    When *func* raises ``DryRun`` the exception escapes ``transaction.atomic``
    (undoing every change made in the block) and is then swallowed, so the
    wrapped call returns ``None``. Otherwise *func*'s return value is
    returned and the transaction is committed.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            with transaction.atomic():
                result = func(*args, **kwargs)
        except DryRun:
            return None
        return result

    return wrapper


class Command(BaseCommand):
    """Import localities, their OIDC services and their users from JSON files.

    Each file holds ``{"data": [...]}``, a bare list, or a single object;
    every item describes one locality::

        {
            "locality": {"slug": ..., "name": ...},
            "services": [{"slug": ..., "name": ..., ...}, ...],
            "users": [{"username": ..., "email": ..., ...}, ...],
        }

    A locality becomes an organizational unit, each service an OIDC client
    (plus an access role unless ``open_to_all`` is set) and each user gets
    membership of the access roles listed in its ``allowed_services``.

    Every file is rewritten as ``{"data": [...]}`` with the effective
    ``client_id``/``client_secret`` of each service and the ``uuid`` of each
    user filled in. Database changes are rolled back unless ``--no-dry-run``
    is given.
    """

    help = "Import localities, OIDC services and users from JSON files"

    def add_arguments(self, parser):
        # A plain flag: without store_true argparse requires a value, and any
        # non-empty value (even "false") would disable the dry run.
        parser.add_argument(
            "--no-dry-run",
            action="store_true",
            default=False,
            help="commit the changes instead of rolling them back",
        )
        parser.add_argument("paths", nargs="+", help="JSON files to import")

    def handle(self, paths, no_dry_run=False, verbosity=1, **options):
        """Import every file of *paths*, then rewrite it with generated ids."""
        self.no_dry_run = no_dry_run
        self.verbosity = verbosity
        for path in paths:
            with open(path) as fd:
                contents = json.load(fd)
            # Accept {"data": [...]}, a list of localities, or a single one.
            if isinstance(contents, dict) and "data" in contents:
                contents = contents["data"] or []
            if not isinstance(contents, list):
                contents = [contents]
            self.do(contents=contents)
            # NOTE(review): the file is rewritten even in dry-run mode, with
            # ids generated inside the rolled-back transaction; confirm intended.
            with open(path, "w") as fd:
                json.dump({"data": contents}, fd, indent=4)

    def info(self, *args, **kwargs):
        """Write to stdout when verbosity is at least 1."""
        if self.verbosity >= 1:
            self.stdout.write(*args, **kwargs)

    @dryrun
    def do(self, contents):
        """Import the localities of *contents* (mutated in place).

        Services get their ``client_id``/``client_secret`` and users their
        ``uuid`` written back. Ends by raising ``DryRun`` (caught by
        ``@dryrun``, which rolls back) unless ``self.no_dry_run`` is set.
        """
        for content in contents:
            ou = self._import_ou(content["locality"])

            content_services = content.get("services", [])
            assert isinstance(content_services, list)
            # service slug -> {"oidc_client": ..., ["access_role": ...]}
            services = {}
            for service in content_services:
                services[service["slug"]] = self._import_service(ou, service)

            content_users = content.get("users", [])
            assert isinstance(content_users, list)
            for content_user in content_users:
                self._import_user(ou, services, content_user)

        if self.no_dry_run:
            return
        raise DryRun

    def _update_instance(self, instance, values):
        """Copy the non-None *values* onto *instance*, save it if anything
        changed, and report MODIFIED or unchanged."""
        modified = False
        for key, value in values.items():
            if value is None:
                continue
            if getattr(instance, key) != value:
                setattr(instance, key, value)
                modified = True
        if modified:
            instance.save()
            self.info(self.style.SUCCESS("MODIFIED"))
        else:
            self.info("unchanged")

    def _import_ou(self, locality):
        """Create the organizational unit of *locality*, or rename it."""
        self.info("Locality %s" % locality["name"], ending=" ")
        ou, created = OrganizationalUnit.objects.get_or_create(
            slug=locality["slug"], defaults={"name": locality["name"]}
        )
        if created:
            self.info(self.style.SUCCESS("CREATED"))
        elif ou.name != locality["name"]:
            ou.name = locality["name"]
            ou.save()
            self.info(self.style.SUCCESS("UPDATED"))
        else:
            self.info("unchanged")
        return ou

    def _import_service(self, ou, service):
        """Create or update the OIDC client, access role and claims of *service*.

        Returns ``{"oidc_client": client}`` plus ``"access_role"`` unless the
        service is ``open_to_all``. Writes the effective ``client_id`` and
        ``client_secret`` back into *service*.
        """
        name = service["name"]
        self.info("Service %s " % name, ending=" ")
        slug = service["slug"]
        # Empty or missing credentials mean "generate on creation, keep on update".
        client_id = service.get("client_id") or None
        client_secret = service.get("client_secret") or None
        frontchannel_logout_uri = service["frontchannel_logout_uri"]
        assert isinstance(frontchannel_logout_uri, six.text_type)
        post_logout_redirect_uris = service.get("post_logout_redirect_uris", [])
        assert isinstance(post_logout_redirect_uris, list)
        open_to_all = service.get("open_to_all", False)
        redirect_uris = service.get("redirect_uris", [])
        assert isinstance(redirect_uris, list)
        has_api_access = service.get("has_api_access", False)
        assert isinstance(has_api_access, bool)
        identifier_policy = service.get("identifier_policy", OIDCClient.POLICY_UUID)
        assert isinstance(identifier_policy, int)
        idtoken_algo = OIDCClient.ALGO_HMAC
        if "idtoken_algo" in service:
            idtoken_algo = getattr(
                OIDCClient, "ALGO_" + service["idtoken_algo"].upper()
            )

        # Field values in their stored form: URI lists are newline-joined
        # text, so they must be compared joined, not as Python lists.
        fields = {
            "name": name,
            "client_id": client_id,
            "client_secret": client_secret,
            "frontchannel_logout_uri": frontchannel_logout_uri,
            "post_logout_redirect_uris": "\n".join(post_logout_redirect_uris),
            "redirect_uris": "\n".join(redirect_uris),
            "has_api_access": has_api_access,
            "idtoken_algo": idtoken_algo,
            "identifier_policy": identifier_policy,
        }
        defaults = dict(fields)
        defaults["client_id"] = client_id or generate_uuid()
        defaults["client_secret"] = client_secret or generate_uuid()
        oidc_client, created = OIDCClient.objects.get_or_create(
            slug=slug, ou=ou, defaults=defaults
        )
        if created:
            self.info(self.style.SUCCESS("CREATED"))
        else:
            self._update_instance(oidc_client, fields)
        service["client_id"] = oidc_client.client_id
        service["client_secret"] = oidc_client.client_secret

        result = {"oidc_client": oidc_client}
        if open_to_all:
            # Everybody may use the service: drop any former access role.
            Role.objects.filter(slug=slug, ou=ou).delete()
        else:
            access_role, created = Role.objects.get_or_create(
                slug=slug, ou=ou, defaults={"name": slug}
            )
            # The slug is the lookup key; only the name can be out of date.
            if not created and access_role.name != slug:
                access_role.name = slug
                access_role.save()
            result["access_role"] = access_role

        self._import_claims(oidc_client, service.get("claims", DEFAULT_CLAIMS))
        return result

    def _import_claims(self, oidc_client, claims):
        """Create the *claims* of *oidc_client* and update their scopes.

        NOTE(review): claims absent from *claims* are not deleted, and a
        claim whose value changes is added next to the old one; confirm.
        """
        assert isinstance(claims, list), "claims must be a list of dict"
        for claim in claims:
            assert isinstance(claim, dict), "claims must be a list of dict"
            name = claim["name"]
            value = claim["value"]
            scopes = claim["scopes"]
            assert name and isinstance(
                name, six.string_types
            ), "claim's name must be a non-empty string"
            assert value and isinstance(
                value, six.string_types
            ), "claim's value must be a non-empty string"
            assert (
                scopes
                and isinstance(scopes, list)
                and all(isinstance(x, six.string_types) for x in scopes)
            ), "claim's scope must be a non-empty list of strings"

            oidc_claim, created = OIDCClaim.objects.get_or_create(
                client=oidc_client,
                name=name,
                value=value,
                defaults={"scopes": " ".join(scopes)},
            )
            if not created and set(oidc_claim.get_scopes()) != set(scopes):
                oidc_claim.scopes = " ".join(scopes)
                oidc_claim.save()

    def _import_user(self, ou, services, content_user):
        """Create or update a user of *ou* and sync its service access.

        The user is looked up by ``uuid`` when given, else by ``username``.
        Writes the user's ``uuid`` back into *content_user*.
        """
        required = ("email", "username")
        data = {}
        for key in ("email", "first_name", "last_name", "password", "username"):
            assert key in content_user, "missing key " + key
            value = content_user[key]
            assert isinstance(value, six.text_type), "invalid type for key " + key
            if key in required:
                assert value, "missing value for key " + key + " %s" % content_user
            data[key] = value
        # Passwords must be {SSHA} hashes; they are stored behind the
        # "plonesha1$" hasher prefix.
        assert data["password"].startswith("{SSHA}")
        data["password"] = "plonesha1$%s" % data["password"]
        uuid = content_user.get("uuid") or None
        assert uuid is None or (
            isinstance(uuid, six.text_type) and uuid
        ), "invalid uuid %s %s" % (uuid, content_user)
        allowed_services = content_user.get("allowed_services", [])
        assert isinstance(allowed_services, list)

        defaults = data.copy()
        if uuid is not None:
            self.info("User %s-%s" % (data["username"], uuid), ending=" ")
            lookup = {"uuid": uuid}
        else:
            self.info("User %s" % data["username"], ending=" ")
            lookup = {"username": defaults.pop("username")}
        user, created = User.objects.get_or_create(ou=ou, defaults=defaults, **lookup)
        if created:
            self.info(self.style.SUCCESS("CREATED"))
        else:
            self._update_instance(user, defaults)
        content_user["uuid"] = user.uuid
        self._sync_access(user, services, allowed_services)

    def _sync_access(self, user, services, allowed_services):
        """Make *user* a member of exactly the access roles of *allowed_services*.

        Services opened to all have no access role and are skipped.
        """
        for service_slug in allowed_services:
            assert service_slug in services, "unknown service %s" % service_slug
            access_role = services[service_slug].get("access_role")
            if access_role is None:
                continue
            oidc_client = services[service_slug]["oidc_client"]
            self.info("Access to service %s" % oidc_client.name, ending=" ")
            if access_role.members.filter(pk=user.pk).exists():
                self.info("unchanged")
            else:
                access_role.members.add(user)
                self.info(self.style.SUCCESS("ADDED"))
        for service_slug in sorted(set(services) - set(allowed_services)):
            access_role = services[service_slug].get("access_role")
            if access_role is None:
                continue
            oidc_client = services[service_slug]["oidc_client"]
            self.info("Access to service %s" % oidc_client.name, ending=" ")
            if access_role.members.filter(pk=user.pk).exists():
                access_role.members.remove(user)
                self.info(self.style.SUCCESS("REMOVED"))
            else:
                self.info("unchanged")
