# passerelle - uniform access to multiple data sources and services
# Copyright (C) 2019 Entr'ouvert
#
# This program is free software: you can redistribute it and/or modify it
# under the terms of the GNU Affero General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.

import datetime

from django.utils import six

import isodate
from lxml import etree as ET
from zeep.utils import qname_attr


def parse_bool(boolean):
    """Convert the lexical form of an xsd:boolean into a Python bool.

    XML Schema allows ``true``/``false`` as well as ``1``/``0``, possibly
    surrounded by whitespace. Matching is case-insensitive, as it has always
    been in this module. A value which is already a ``bool`` is returned
    unchanged, mirroring what parse_date() does for already parsed dates.

    Any other string is considered false.
    """
    if isinstance(boolean, bool):
        return boolean
    return boolean.strip().lower() in ('true', '1')


def parse_date(date):
    """Convert a ``YYYY-MM-DD`` string into a ``datetime.date``.

    A value which is already a ``datetime.date`` (this includes
    ``datetime.datetime`` instances, which subclass it) is returned unchanged.

    Raises ``ValueError`` if the string does not match ``%Y-%m-%d``.
    """
    if isinstance(date, datetime.date):
        return date
    # strptime() takes the string to parse first, then the format
    return datetime.datetime.strptime(date, '%Y-%m-%d').date()


# Namespace URI of XML Schema and a prefix -> namespace mapping for it
XSD = 'http://www.w3.org/2001/XMLSchema'
ns = {'xsd': XSD}

# Qualified names of the XML Schema elements recognized by Schema.visit*()
SCHEMA = ET.QName(XSD, 'schema')
ANNOTATION = ET.QName(XSD, 'annotation')
ELEMENT = ET.QName(XSD, 'element')
ATTRIBUTE = ET.QName(XSD, 'attribute')
COMPLEX_TYPE = ET.QName(XSD, 'complexType')
SIMPLE_TYPE = ET.QName(XSD, 'simpleType')
COMPLEX_CONTENT = ET.QName(XSD, 'complexContent')
EXTENSION = ET.QName(XSD, 'extension')
RESTRICTION = ET.QName(XSD, 'restriction')
SEQUENCE = ET.QName(XSD, 'sequence')
CHOICE = ET.QName(XSD, 'choice')
ALL = ET.QName(XSD, 'all')
# Qualified names of the XML Schema built-in types which can be cast
BOOLEAN = ET.QName(XSD, 'boolean')
STRING = ET.QName(XSD, 'string')
DATE = ET.QName(XSD, 'date')
INT = ET.QName(XSD, 'int')
INTEGER = ET.QName(XSD, 'integer')
DATE_TIME = ET.QName(XSD, 'dateTime')
ANY_TYPE = ET.QName(XSD, 'anyType')

# Built-in type -> callable converting the text of an element into a Python
# value; Path looks its caster up here and applies it in Path.resolve()
TYPE_CASTER = {
    BOOLEAN: parse_bool,
    STRING: six.text_type,
    DATE: parse_date,
    INT: int,
    INTEGER: int,
    DATE_TIME: isodate.parse_datetime,
    # anyType: the text is returned untouched
    ANY_TYPE: lambda v: v,
}


class Schema(object):
    """Simplified in-memory description of an XML Schema document.

    visit() walks the root ``xsd:schema`` element of a parsed schema and fills
    ``types`` and ``elements`` with plain dictionaries describing the top
    level type and element declarations. paths() then enumerates every leaf
    which can be reached from the top level elements.

    Only a subset of XML Schema is supported; anything else makes the visit_*
    methods fail with an ``AssertionError`` or a ``NotImplementedError``.
    """

    def __init__(self):
        # qualified name -> description of the top level complex/simple types
        self.types = {}
        # qualified name -> description of the top level elements
        self.elements = {}
        self.target_namespace = None
        self.element_form_default = 'qualified'
        self.attribute_form_default = 'unqualified'
        # prefix -> namespace URI, copied from the schema root by visit()
        self.nsmap = {}
        # namespace URI -> prefix; initialized here so that qname_display()
        # can be used on an instance which has not visited a schema yet
        self.reverse_nsmap = {}

    def visit(self, root):
        """Load the declarations found under the ``xsd:schema`` element `root`.

        Raises ``AssertionError`` if `root` is not a schema element, carries
        unsupported attributes or contains an unnamed declaration, and
        ``NotImplementedError`` if a top level child is not a complexType, a
        simpleType or an element.
        """
        assert root.tag == SCHEMA
        assert set(root.attrib) <= set(['targetNamespace', 'elementFormDefault', 'attributeFormDefault']), (
            'unsupported schema attributes %s' % root.attrib
        )
        self.target_namespace = root.get('targetNamespace')
        self.element_form_default = root.get('elementFormDefault', self.element_form_default)
        self.attribute_form_default = root.get('attributeFormDefault', self.attribute_form_default)
        self.nsmap = root.nsmap
        self.reverse_nsmap = {value: key for key, value in self.nsmap.items()}

        # first pass: check every top level declaration is supported and
        # named, and register a placeholder for it
        # NOTE(review): placeholders are keyed by the name as returned by
        # qname_attr(); when the second pass re-keys a declaration with the
        # target namespace the placeholder seems to be left behind - confirm
        # this is intended.
        for node in root:
            if node.tag == COMPLEX_TYPE:
                name = qname_attr(node, 'name')
                assert name, 'unsupported top complexType without name'
                self.types[name] = {}
            elif node.tag == ELEMENT:
                name = qname_attr(node, 'name')
                assert name, 'unsupported top element without name'
                self.elements[name] = {}
            elif node.tag == SIMPLE_TYPE:
                name = qname_attr(node, 'name')
                assert name, 'unsupported top simpleType without name'
                self.types[name] = {}
            else:
                raise NotImplementedError('unsupported top element %s' % node)

        # second pass: build the real descriptions
        for node in root:
            if node.tag == COMPLEX_TYPE:
                d = self.visit_complex_type(node)
                target = self.types
            elif node.tag == SIMPLE_TYPE:
                d = self.visit_simple_type(node)
                target = self.types
            elif node.tag == ELEMENT:
                d = self.visit_element(node)
                target = self.elements
            else:
                raise NotImplementedError
            # names without a namespace are moved into the target namespace
            if not d['name'].namespace and self.target_namespace:
                d['name'] = ET.QName(self.target_namespace, d['name'].localname)
            target[d['name']] = d

    def visit_simple_type(self, node):
        """Describe an ``xsd:simpleType``.

        Only a restriction of ``xsd:string`` is supported; the facets of the
        restriction are not looked at. Returns a dict with ``type`` and, when
        the node is named, ``name``.
        """
        # ignore annotations
        children = [child for child in node if child.tag != ANNOTATION]
        d = {}
        name = qname_attr(node, 'name')
        if name:
            d['name'] = name
        assert len(children) == 1, list(node)
        assert children[0].tag == RESTRICTION
        xsd_type = qname_attr(children[0], 'base')
        assert xsd_type == STRING
        d['type'] = STRING
        return d

    def visit_complex_content(self, node):
        """Describe an ``xsd:complexContent``.

        The node must have a single ``xsd:extension`` child, whose ``base``
        becomes the ``type`` of the returned dict; the content of the
        extension itself is not looked at.
        """
        d = {}
        name = qname_attr(node, 'name')
        if name:
            d['name'] = name
        assert len(node) == 1
        assert node[0].tag == EXTENSION
        xsd_type = qname_attr(node[0], 'base')
        d['type'] = xsd_type
        return d

    def visit_complex_type(self, node):
        """Describe an ``xsd:complexType``.

        The first child (annotations excluded) can be a sequence, a choice, an
        all or a complexContent; every other child must be an
        ``xsd:attribute``. Attributes are collected under the ``attributes``
        key, indexed by name.
        """
        # ignore annotations
        children = [child for child in node if child.tag != ANNOTATION]
        if children and children[0].tag in (SEQUENCE, CHOICE, ALL, COMPLEX_CONTENT):
            if children[0].tag == SEQUENCE:
                d = self.visit_sequence(children[0])
            elif children[0].tag == CHOICE:
                d = self.visit_choice(children[0])
            elif children[0].tag == ALL:
                d = self.visit_all(children[0])
            elif children[0].tag == COMPLEX_CONTENT:
                d = self.visit_complex_content(children[0])
            children = children[1:]
        else:
            d = {}
        for child in children:
            assert child.tag == ATTRIBUTE, 'unsupported complexType with child %s' % child
            name = qname_attr(child, 'name')
            assert name, 'attribute without a name %s' % ET.tostring(child)
            assert set(child.attrib) <= set(['use', 'type', 'name']), child.attrib
            attributes = d.setdefault('attributes', {})
            xsd_type = qname_attr(child, 'type')
            attributes[name] = {
                'name': name,
                'use': child.get('use', 'optional'),
                'type': xsd_type,
            }

        name = qname_attr(node, 'name')
        if name:
            d['name'] = name
        return d

    def visit_element(self, node, top=False):
        """Describe an ``xsd:element``.

        Supported elements are named and either contain a single anonymous
        complexType or have no child and reference a type through their
        ``type`` attribute (``xsd:string`` if it is missing).

        `top` is not used; it is kept for backward compatibility.

        Returns a dict with ``name``, ``min_occurs``, ``max_occurs`` (an int
        or ``'unbounded'``) and the description of the type.
        """
        assert set(node.attrib.keys()) <= set(['name', 'type', 'minOccurs', 'maxOccurs']), node.attrib
        # ignore annotations
        children = [child for child in node if child.tag != ANNOTATION]
        # we handle elements with a name and one child, an anonymous complex type
        # or element without children referencing a complex type
        name = qname_attr(node, 'name')
        assert name is not None
        min_occurs = node.attrib.get('minOccurs') or 1
        max_occurs = node.attrib.get('maxOccurs') or 1
        d = {
            'name': name,
            'min_occurs': int(min_occurs),
            'max_occurs': max_occurs if max_occurs == 'unbounded' else int(max_occurs),
        }
        if len(children) == 1:
            ctype_node = children[0]
            assert ctype_node.tag == COMPLEX_TYPE
            assert ctype_node.attrib == {}
            d.update(self.visit_complex_type(ctype_node))
            return d
        elif len(children) == 0:
            xsd_type = qname_attr(node, 'type')
            if xsd_type is None:
                xsd_type = STRING
            d['type'] = xsd_type
            return d
        else:
            raise NotImplementedError('unsupported element with more than one children %s' % list(node))

    def visit_sequence(self, node):
        """Describe an ``xsd:sequence`` made of elements and choices.

        Returns a dict with the list of children descriptions under
        ``sequence`` and, if the node has a ``maxOccurs`` attribute,
        ``max_occurs`` as an int or ``'unbounded'``.
        """
        assert set(node.attrib) <= set(['maxOccurs']), node.attrib
        sequence = []

        for element_node in node:
            assert element_node.tag in (
                ELEMENT,
                CHOICE,
            ), 'unsupported sequence with child not an element or a choice %s' % ET.tostring(element_node)
            if element_node.tag == ELEMENT:
                sequence.append(self.visit_element(element_node))
            elif element_node.tag == CHOICE:
                sequence.append(self.visit_choice(element_node))

        d = {
            'sequence': sequence,
        }
        if 'maxOccurs' in node.attrib:
            # same representation as in visit_element(): paths() does
            # arithmetic on this value, a raw attribute string would break it
            max_occurs = node.get('maxOccurs')
            d['max_occurs'] = max_occurs if max_occurs == 'unbounded' else int(max_occurs)
        return d

    def visit_all(self, node):
        """Describe an ``xsd:all``, handled exactly like a sequence."""
        return self.visit_sequence(node)

    def visit_choice(self, node):
        """Describe an ``xsd:choice`` without attributes, made of elements only.

        Returns a dict with the list of element descriptions under ``choice``.
        """
        assert node.attrib == {}, 'unsupported choice with attributes %s' % node.attrib
        choice = []

        for element_node in node:
            assert element_node.tag == ELEMENT, (
                'unsupported choice with child not an element %s' % ET.tostring(element_node)
            )
            choice.append(self.visit_element(element_node))

        return {'choice': choice}

    def qname_display(self, name):
        """Return a text representation of the qualified name `name`.

        The ``prefix:localname`` form is used when the namespace of the name
        is bound to a prefix in the visited schema; otherwise the text form of
        the name itself is returned.
        """
        # a namespace declared as the default one has no prefix (None)
        prefix = self.reverse_nsmap.get(name.namespace)
        if prefix is not None:
            name = '%s:%s' % (prefix, name.localname)
        return six.text_type(name)

    def paths(self):
        """Generate ``(path, xsd_type)`` for every leaf under the top level elements.

        `path` is the list of qualified names leading to the leaf. Repeated
        content is enumerated with ``_1``, ``_2``... suffixes added to the
        local name, in addition to the name without suffix; ``'unbounded'``
        repetitions are limited to 3.
        """
        roots = sorted(self.elements.keys())

        def helper(path, ctype, is_type=False):
            # is_type is set when ctype is the description of a named type
            # referenced by an element: its own name is not part of the path
            name = None
            if 'name' in ctype:
                name = ctype['name']
            max_occurs = ctype.get('max_occurs', 1)
            max_occurs = 3 if max_occurs == 'unbounded' else max_occurs
            if 'type' in ctype:
                if name and not is_type:
                    path = path + [name]
                xsd_type = ctype['type']
                if xsd_type in self.types:
                    # reference to a type declared in this schema
                    sub_type = self.types[xsd_type]
                    for subpath in helper(path, sub_type, is_type=True):
                        yield subpath
                else:
                    if max_occurs > 1:
                        for i in range(max_occurs):
                            yield path[:-1] + [
                                ET.QName(name.namespace, name.localname + '_%d' % (i + 1))
                            ], xsd_type
                    yield path, xsd_type
            else:
                if max_occurs == 1:
                    extensions = ['']
                else:
                    extensions = [''] + ['_%s' % i for i in range(1, max_occurs + 1)]
                for extension in extensions:
                    new_path = path
                    if name and not is_type:
                        new_path = new_path + [ET.QName(name.namespace, name.localname + extension)]
                    if 'sequence' in ctype:
                        for sub_ctype in ctype['sequence']:
                            for subpath in helper(new_path, sub_ctype):
                                yield subpath
                    elif 'choice' in ctype:
                        for sub_ctype in ctype['choice']:
                            for subpath in helper(new_path, sub_ctype):
                                yield subpath

        for root in roots:
            for path in helper([], self.elements[root]):
                yield path


@six.python_2_unicode_compatible
class Path(object):
    """Typed location of a value inside an XML document.

    `path` is the non-empty list of tags to follow, starting with the tag of
    the root element; `xsd_type` selects, through TYPE_CASTER, the callable
    used to convert the text found at that location.
    """

    def __init__(self, path, xsd_type):
        """Raise ``KeyError`` if no caster is known for `xsd_type`."""
        assert path
        self.path = path
        self.xsd_type = xsd_type
        if xsd_type not in TYPE_CASTER:
            raise KeyError(six.text_type(xsd_type))
        self.caster = TYPE_CASTER[xsd_type]

    def resolve(self, root):
        """Return the converted text of the node designated by this path.

        ``None`` is returned when the tag of `root` is not the first step of
        the path, when a step cannot be found, or when the node reached has
        no text or has children. At each step only the first child with a
        matching tag is followed.
        """
        steps = iter(self.path)
        if root.tag != next(steps):
            return None
        current = root
        for wanted in steps:
            current = next((sub for sub in current if sub.tag == wanted), None)
            if current is None:
                return None
        if current.text and not list(current):
            return self.caster(current.text)
        return None

    def __str__(self):
        return '.'.join(map(six.text_type, self.path))
