# passerelle - © Entr'ouvert
# "safe" python eval

import ast
import collections
import hashlib
import types


class EvalError(SyntaxError):
    @classmethod
    def with_node(cls, msg, node, text=''):
        text = text or ast.unparse(node)
        return cls(
            msg, ('<string>', node.lineno, node.col_offset, text, node.end_lineno, node.end_col_offset)
        )


class StaticCheckError(EvalError):
    """Raised while inspecting an expression's syntax tree, before it is compiled."""


class DynamicCheckError(EvalError):
    """Raised by the runtime guards (multiplication, iteration budget) during evaluation."""


class VerifyAndSafetize(ast.NodeVisitor):
    """Check an expression tree against a whitelist and rewrite it for sandboxed evaluation.

    ``visit()`` is meant to be called on the ``ast.Expression`` produced by
    ``ast.parse(..., mode='eval')``.  Every node must be an instance of a type
    listed in ``authorized`` (and not of a type in ``forbidden``), otherwise
    ``StaticCheckError`` is raised.  While walking, the tree is rewritten in place:

    * ``left * right`` becomes ``__multop(left, right)`` so operands are checked
      when the expression runs;
    * each comprehension clause ``for target in iterable`` becomes
      ``for (target, __sNNNNNN) in __zip(iterable, __comprehension_counter)``
      so every loop iteration draws one value from a shared counter.

    The evaluation globals must therefore provide ``__multop``, ``__zip`` and
    ``__comprehension_counter``.
    """

    # Current comprehension nesting level; saved and restored by visit_ListComp.
    comprehension_depth = 0

    # Node types accepted in the tree; abstract bases (ast.operator, ...) cover
    # all their concrete subclasses.
    authorized = [
        ast.Expression,
        ast.Load,
        ast.Store,
        ast.Name,
        ast.BoolOp,
        ast.UnaryOp,
        ast.BinOp,
        ast.IfExp,
        ast.Index,
        ast.Compare,
        # constant data
        ast.Constant,
        ast.List,
        ast.Tuple,
        # operators
        ast.operator,
        ast.boolop,
        ast.cmpop,
        ast.unaryop,
        # complex
        ast.Subscript,
        ast.Call,
        ast.Attribute,
        ast.ListComp,
        # misc
        ast.comprehension,
        ast.Slice,
        ast.JoinedStr,
        ast.FormattedValue,
    ]

    # Node types rejected even when authorized. It is {ast.Store} in general, and
    # visit_comprehension temporarily swaps it while visiting a loop target.
    forbidden = None

    # Names that may be called directly, as in ``len(x)``; extended per instance.
    authorized_functions = [
        'int',
        'str',
        'bool',
        'float',
        'len',
        'sorted',
    ]

    # Attribute names that may never be accessed, in addition to any name
    # starting with '_' (see visit_Attribute). Grouped by owning type, hence
    # the duplicates.
    # NOTE(review): 'format' / 'format_map' are not listed; str.format can read
    # attributes through replacement fields such as '{0.__class__}', which
    # sidesteps the '_' check -- consider forbidding them.
    forbidden_attribute = {  # pylint: disable=duplicate-value
        # list modifying methods
        'append',
        'clear',
        'extend',
        'insert',
        'pop',
        'remove',
        'reverse',
        'sort',
        # dict modifying methods
        'clear',
        'pop',
        'popitem',
        'setdefault',
        'update',
        # set modifying methods
        'add',
        'clear',
        'discard',
        'pop',
        'remove',
        'update',
        'difference_update',
        'intersection_update',
        'symmetric_difference_update',
        # file descriptor modifying methods
        'close',
        'detach',
        'flush',
        'reconfigure',
        'truncate',
        'write',
        'writelines',
        # type.mro
        'mro',
    }

    def __init__(self, visible=None, max_comprehension_depth=2, authorized_functions=()):
        """
        :param visible: optional container of allowed names; falsy means no restriction.
        :param max_comprehension_depth: maximum nesting of comprehension clauses.
        :param authorized_functions: extra names callable directly, added to the defaults.
        """
        self.visible = visible
        self.max_comprehension_depth = max_comprehension_depth
        self.authorized_functions = self.authorized_functions[:] + list(authorized_functions)
        self.forbidden = {ast.Store}
        self.symcounter = 0
        # Stack of located nodes being visited, used to position errors raised
        # on nodes that carry no location (e.g. ast.comprehension).
        self.context = []

    def gensym(self):
        """Return a fresh private variable name for this visitor."""
        self.symcounter += 1
        return f'__s{self.symcounter:06d}'

    def visit(self, node):
        """Dispatch to the visitor method, tracking located nodes in ``self.context``."""
        is_syntax = hasattr(node, 'lineno')
        if is_syntax:
            self.context.append(node)
        try:
            return super().visit(node)
        finally:
            if is_syntax:
                self.context.pop()

    def visit_BinOp(self, node):
        """Reject ``**`` and route ``*`` through the runtime-checked ``__multop``."""
        # NOTE(review): only ** and * are guarded; '<<' with a large right operand
        # can also build huge integers -- consider guarding it as well.
        if isinstance(node.op, ast.Pow):
            raise StaticCheckError.with_node('unauthorized pow operator', node=node)
        self.generic_visit(node)
        if isinstance(node.op, ast.Mult):
            return ast.Call(
                func=ast.Name(id='__multop', ctx=ast.Load()),
                args=[node.left, node.right],
                keywords=[],
            )
        return node

    def visit_comprehension(self, node):
        """Check one ``for ... in ... if ...`` clause and bind it to the iteration counter.

        :raises StaticCheckError: if nesting exceeds ``max_comprehension_depth``.
        """
        # The loop target is a binding: Store contexts are legal here, but
        # assigning to attributes or subscripts (``for x.y in ...``) is not.
        self.forbidden = {ast.Attribute, ast.Subscript}
        try:
            node.target = self.visit(node.target)
        finally:
            self.forbidden = {ast.Store}

        self.comprehension_depth += 1
        if self.comprehension_depth > self.max_comprehension_depth:
            raise StaticCheckError.with_node(
                f'unauthorized comprehension depth, more than {self.max_comprehension_depth}',
                # comprehension nodes have no location: report the enclosing node
                node=self.context[-1],
            )

        # Keep the visitors' return values: a top-level ``a * b`` here is
        # replaced by a __multop call and must not be dropped.
        node.iter = self.visit(node.iter)
        node.ifs = [self.visit(if_node) for if_node in node.ifs]

        var = self.gensym()
        # Pull from the iterable first, then from the counter, so that reaching
        # the end of the iterable does not consume a tick: a loop over N items
        # costs exactly N ticks.
        node.target = ast.Tuple(
            elts=[node.target, ast.Name(var, ctx=ast.Store())],
            ctx=ast.Store(),
        )
        node.iter = ast.Call(
            func=ast.Name(id='__zip', ctx=ast.Load()),
            args=[
                node.iter,
                ast.Name(id='__comprehension_counter', ctx=ast.Load()),
            ],
            keywords=[],
        )
        return node

    def visit_Name(self, node):
        """Reject unknown names (when ``visible`` is set) and private names."""
        # Comprehension targets go through this check too, so when ``visible``
        # is set, loop variable names must be listed in it.
        if self.visible and node.id not in self.visible:
            raise StaticCheckError.with_node(f'unknown name {node.id}', node=node)
        if node.id.startswith('_') or node.id in ['comprehension_counter']:
            raise StaticCheckError.with_node(f'unauthorized name {node.id}', node=node)
        return self.generic_visit(node)

    def visit_Attribute(self, node):
        """Reject private attributes and the attribute names in ``forbidden_attribute``."""
        if node.attr.startswith('_') or node.attr in self.forbidden_attribute:
            raise StaticCheckError.with_node(f'unauthorized name {node.attr}', node=node)
        return self.generic_visit(node)

    def visit_Call(self, node):
        """Only allow direct calls to names in ``authorized_functions``."""
        if isinstance(node.func, ast.Name) and node.func.id not in self.authorized_functions:
            raise StaticCheckError.with_node(f'unauthorized call to {node.func.id}', node=node)
        return self.generic_visit(node)

    def visit_ListComp(self, node):
        """Visit the clauses before the element so nested comprehensions see the depth."""
        self._check_authorized(node)
        depth = self.comprehension_depth
        try:
            # Explicit order instead of patching the global ast.ListComp._fields:
            # generators first (each one increments the depth), then elt.
            node.generators = [self.visit(generator) for generator in node.generators]
            node.elt = self.visit(node.elt)
        finally:
            # Sibling comprehensions must not accumulate depth.
            self.comprehension_depth = depth
        return node

    def _check_authorized(self, node):
        """Raise StaticCheckError if *node*'s type is currently forbidden or not whitelisted."""
        is_forbidden = self.forbidden and isinstance(node, tuple(self.forbidden))
        if is_forbidden or not isinstance(node, tuple(self.authorized)):
            raise StaticCheckError.with_node(f'unauthorized {type(node).__name__}', node=node)

    def generic_visit(self, node):
        """Check *node* itself, then visit its children, replacing them with the results."""
        self._check_authorized(node)
        for field, value in ast.iter_fields(node):
            if isinstance(value, list):
                new = [self.visit(item) if isinstance(item, ast.AST) else item for item in value]
            elif isinstance(value, ast.AST):
                new = self.visit(value)
            else:
                continue
            setattr(node, field, value if new is None else new)
        return node


# Maximum number of compiled expressions kept in eval_cache.
MAX_CACHE_SIZE = 1_000
# Number of values comprehension_counter() yields before raising.
MAX_ITERATION = 100
# Default for safe_compile()'s max_comprehension_depth.
MAX_COMPREHENSION_DEPTH = 2

# Compiled code objects, maintained by safe_compile().
eval_cache = collections.OrderedDict()


def multop(left, right):
    """Multiply two numbers after checking that the product is safe to compute.

    Sandboxed expressions call this instead of using ``*`` directly:
    VerifyAndSafetize rewrites ``a * b`` to ``__multop(a, b)``.

    :raises DynamicCheckError: unless both operands have exactly the same type,
        that type is ``int`` or ``float``, and (for ints) neither operand is
        longer than 64 bits.
    """
    operand_type = type(left)
    # Same exact type on both sides, limited to int/float: this rules out
    # sequence repetition ('ab' * 3, [0] * 3), bools and mixed int/float.
    if operand_type != type(right) or operand_type not in (int, float):  # pylint: disable=unidiomatic-typecheck
        raise DynamicCheckError('unauthorized multiplication')
    # Bound integer operand size so the product cannot grow without limit.
    if operand_type is int and max(left.bit_length(), right.bit_length()) > 64:
        raise DynamicCheckError('unauthorized multiplication')
    return left * right


def comprehension_counter():
    """Yield MAX_ITERATION ticks, then raise DynamicCheckError.

    safe_eval() creates one counter per evaluation, and the rewritten
    comprehensions zip every loop with it, so this counter is the iteration
    budget shared by a whole expression.
    """
    for tick in range(MAX_ITERATION):
        yield tick
    raise DynamicCheckError(f'unauthorized, more than {MAX_ITERATION} iterations')


def safe_compile(
    expression, visible=None, max_comprehension_depth=MAX_COMPREHENSION_DEPTH, authorized_functions=()
):
    """Statically check *expression* and compile it for use with safe_eval().

    Results are memoized in ``eval_cache``, a least-recently-used cache holding
    at most ``MAX_CACHE_SIZE`` entries.  The cache key includes every option
    that influences the static check.  Code accepted under permissive options
    is therefore never returned to a caller using stricter ones.

    :param expression: source of a single Python expression.
    :param visible: optional container of allowed names; falsy means no restriction.
    :param max_comprehension_depth: maximum nesting of comprehension clauses.
    :param authorized_functions: extra names that may be called directly.
    :returns: a code object compiled in ``'eval'`` mode.
    :raises SyntaxError: on invalid syntax, or (as StaticCheckError) when the
        expression uses a construct the checker rejects.
    """
    key = (
        expression,
        # falsy ``visible`` disables the name restriction, like None
        tuple(visible) if visible else None,
        max_comprehension_depth,
        tuple(authorized_functions),
    )
    compiled = eval_cache.get(key)
    if compiled is None:
        tree = ast.parse(expression, mode='eval')
        visitor = VerifyAndSafetize(
            visible=visible,
            max_comprehension_depth=max_comprehension_depth,
            authorized_functions=authorized_functions,
        )
        new_tree = visitor.visit(tree)
        ast.fix_missing_locations(new_tree)
        compiled = compile(new_tree, '', mode='eval')
        eval_cache[key] = compiled

    # Mark the entry as most recently used (re-assigning a key does not move it
    # in an OrderedDict), then evict from the least recently used end.
    eval_cache.move_to_end(key)
    while len(eval_cache) > MAX_CACHE_SIZE:
        eval_cache.popitem(last=False)

    return compiled


def safe_eval(expression, kwargs, **compile_kwargs):
    """Evaluate a sandboxed expression with the names in *kwargs* as its globals.

    :param expression: expression source, or a code object already produced by
        safe_compile() (in which case *compile_kwargs* are ignored).
    :param kwargs: mapping of variable names available to the expression.
    :param compile_kwargs: forwarded to safe_compile() for source strings.
    :returns: the value of the expression.
    """
    if not isinstance(expression, types.CodeType):
        expression = safe_compile(expression, **compile_kwargs)
    # Runtime helpers and whitelisted builtins. They are merged after the
    # caller's variables so that those can never shadow them.
    sandbox = {
        '__builtins__': {},
        '__multop': multop,
        '__comprehension_counter': comprehension_counter(),
        '__zip': zip,
        'int': int,
        'str': str,
        'bool': bool,
        'float': float,
        'len': len,
        'sorted': sorted,
    }
    return eval(expression, {**kwargs, **sandbox})  # pylint: disable=eval-used
