from typed_ast import ast27
from typed_ast import ast3

def py2to3(ast):
    """Convert a typed Python 2.7 AST into a typed Python 3.5 AST.

    The result is a valid Python 3 AST, apart from two deviations:

    - An `arg` object may hold a Tuple object rather than a plain identifier
      when the Python 2 function definition or lambda used the tuple
      unpacking syntax for its parameters.
    - A `Raise` object gains an extra `traceback` attribute when the
      three-argument form of the Python 2 raise statement was used.

    Strange and rare edge case that is not covered:

    - Raise: when the second argument of a raise statement is a tuple, its
      items are unpacked as arguments to the exception constructor.  That is
      handled correctly for a literal tuple, but not for any other kind of
      expression that produces a tuple.
    """
    converter = _AST2To3()
    return converter.visit(ast)

def _copy_attributes(new_value, old_value):
    """Copy every attribute listed in `old_value._attributes` onto `new_value`.

    When `old_value` has no `_attributes` (or it is None), nothing is copied.
    `new_value` is modified in place and returned.
    """
    attribute_names = getattr(old_value, '_attributes', None)
    if attribute_names is None:
        return new_value

    for name in attribute_names:
        setattr(new_value, name, getattr(old_value, name))
    return new_value

class _AST2To3(ast27.NodeTransformer):
    """Transformer that rebuilds an `ast27` tree out of `ast3` nodes.

    Node types without a dedicated ``visit_*`` method go through
    `generic_visit`, which instantiates the `ast3` class of the same name and
    converts each field.  The ``visit_*`` methods below handle the node types
    that need a different shape in the Python 3 tree.
    """
    # note: None, True, and False are *not* translated into NameConstants.
    def __init__(self):
        pass

    def visit(self, node):
        """Convert `node` and copy its `_attributes` values onto the result.

        Dispatches to ``visit_<ClassName>`` when this class defines it and to
        `generic_visit` otherwise.  Lists take the same route and therefore
        end up in `visit_list`.
        """
        method = 'visit_' + node.__class__.__name__
        visitor = getattr(self, method, self.generic_visit)
        return _copy_attributes(visitor(node), node)

    def maybe_visit(self, node):
        """Like `visit`, but return None unchanged when `node` is None."""
        if node is None:
            return None
        return self.visit(node)

    def generic_visit(self, node):
        """Convert `node` to the `ast3` class that has the same class name.

        Fields holding an `ast27` node or a list are converted recursively;
        every other field value is copied over as-is.  Raises AttributeError
        if `ast3` has no class of that name.
        """
        class_name = node.__class__.__name__
        converted_class = getattr(ast3, class_name)
        new_node = converted_class()
        for field, old_value in ast27.iter_fields(node):
            if isinstance(old_value, (ast27.AST, list)):
                setattr(new_node, field, self.visit(old_value))
            else:
                setattr(new_node, field, old_value)
        return new_node

    def visit_list(self, l):
        """Return a new list with every node / nested list element converted."""
        return [self.visit(e) if isinstance(e, (ast27.AST, list)) else e for e in l]

    def visit_FunctionDef(self, n):
        """Convert generically, then set `returns` to None."""
        new = self.generic_visit(n)
        new.returns = None
        return new

    def visit_ClassDef(self, n):
        """Convert generically, then set `keywords` to an empty list."""
        new = self.generic_visit(n)
        new.keywords = []
        return new

    def visit_TryExcept(self, n):
        """Convert a TryExcept into a `Try` with an empty final body."""
        return ast3.Try(self.visit(n.body),
                        self.visit(n.handlers),
                        self.visit(n.orelse),
                        [])

    def visit_TryFinally(self, n):
        """Convert a TryFinally into a `Try`.

        When the body is exactly one TryExcept, the two statements are merged
        into a single `Try` that carries both the handlers and the final body.
        Otherwise the result is a `Try` with no handlers and no else branch.
        """
        if len(n.body) == 1 and isinstance(n.body[0], ast27.TryExcept):
            new = self.visit(n.body[0])
            new.finalbody = self.visit(n.finalbody)
            return new
        else:
            return ast3.Try(self.visit(n.body),
                            [],
                            [],
                            self.visit(n.finalbody))

    def visit_ExceptHandler(self, n):
        """Convert an except clause; the bound name becomes a plain string.

        Raises RuntimeError when the handler's name is neither None nor a
        `Name` node.
        """
        if n.name is None:
            name = None
        elif isinstance(n.name, ast27.Name):
            name = n.name.id
        else:
            raise RuntimeError("'{}' has non-Name name.".format(ast27.dump(n)))

        return ast3.ExceptHandler(self.maybe_visit(n.type),
                                  name,
                                  self.visit(n.body))

    def visit_Print(self, n):
        """Convert a print statement into an expression calling `print`.

        `dest` is passed as the `file` keyword; when `nl` is false an
        `end=" "` keyword is added.
        """
        keywords = []
        if n.dest is not None:
            keywords.append(ast3.keyword("file", self.visit(n.dest)))

        if not n.nl:
            keywords.append(ast3.keyword("end",
                                         ast3.Str(s=" ", kind='', lineno=n.lineno, col_offset=-1)))

        return ast3.Expr(ast3.Call(ast3.Name("print", ast3.Load(), lineno=n.lineno, col_offset=-1),
                                   self.visit(n.values),
                                   keywords,
                                   lineno=n.lineno, col_offset=-1))

    def visit_Raise(self, n):
        """Convert a Python 2 raise statement into a Python 3 `Raise`.

        When an instance argument is present (and is not the name `None`), the
        exception becomes a call of the converted type: a converted `Tuple`
        supplies its elements as the call arguments, anything else is passed
        as the single argument.  A traceback argument is stored on the result
        as an extra `traceback` attribute.
        """
        e = None
        if n.type is not None:
            e = self.visit(n.type)

            if n.inst is not None and not (isinstance(n.inst, ast27.Name) and n.inst.id == "None"):
                inst = self.visit(n.inst)
                if isinstance(inst, ast3.Tuple):
                    args = inst.elts
                else:
                    args = [inst]
                e = ast3.Call(e, args, [], lineno=e.lineno, col_offset=-1)

        ret = ast3.Raise(e, None)
        if n.tback is not None:
            ret.traceback = self.visit(n.tback)
        return ret

    def visit_Exec(self, n):
        """Convert an exec statement into an expression calling `exec`.

        The call always receives three arguments: body, globals, locals.  A
        missing globals or locals is replaced by a `Name` node for `None`.
        """
        new_globals = self.maybe_visit(n.globals)
        if new_globals is None:
            new_globals = ast3.Name("None", ast3.Load(), lineno=-1, col_offset=-1)
        new_locals = self.maybe_visit(n.locals)
        if new_locals is None:
            new_locals = ast3.Name("None", ast3.Load(), lineno=-1, col_offset=-1)

        return ast3.Expr(ast3.Call(ast3.Name("exec", ast3.Load(), lineno=n.lineno, col_offset=-1),
                                   [self.visit(n.body), new_globals, new_locals],
                                   [],
                                   lineno=n.lineno, col_offset=-1))

    # TODO(ddfisher): the name repr could be used locally as something else; disambiguate
    def visit_Repr(self, n):
        """Convert a Repr node into a call of the name `repr`."""
        return ast3.Call(ast3.Name("repr", ast3.Load(), lineno=n.lineno, col_offset=-1),
                         [self.visit(n.value)],
                         [])

    # TODO(ddfisher): this will cause strange behavior on multi-item with statements with type comments
    def visit_With(self, n):
        """Convert a With node into a `With` holding a single `withitem`."""
        return ast3.With([ast3.withitem(self.visit(n.context_expr), self.maybe_visit(n.optional_vars))],
                          self.visit(n.body),
                          n.type_comment)

    def visit_Call(self, n):
        """Convert a call, folding `starargs` / `kwargs` into the lists.

        `starargs` is appended to the positional arguments wrapped in a
        `Starred` node; `kwargs` is appended to the keywords as a `keyword`
        whose name is None.
        """
        args = self.visit(n.args)
        if n.starargs is not None:
            args.append(ast3.Starred(self.visit(n.starargs), ast3.Load(), lineno=n.starargs.lineno, col_offset=n.starargs.col_offset))

        keywords = self.visit(n.keywords)
        if n.kwargs is not None:
            keywords.append(ast3.keyword(None, self.visit(n.kwargs)))

        return ast3.Call(self.visit(n.func),
                         args,
                         keywords)

    # TODO(ddfisher): find better attributes to give Ellipses
    def visit_Ellipsis(self, n):
        """Convert an Ellipsis into an `Index` wrapping an `ast3.Ellipsis`."""
        # ellipses in Python 2 only exist as a slice index
        return ast3.Index(ast3.Ellipsis(lineno=-1, col_offset=-1))

    def visit_arguments(self, n):
        """Convert a parameter list into `ast3.arguments`.

        Each positional parameter becomes an `arg` whose name is either the
        identifier string (for a `Name`) or a converted `Tuple`; any other
        node raises RuntimeError.  Type comments are taken from
        `n.type_comments` by position: positional parameters first, then the
        vararg, then the kwarg.  The two keyword-only slots are left empty.
        """
        def convert_arg(arg, type_comment):
            if isinstance(arg, ast27.Name):
                v = arg.id
            elif isinstance(arg, ast27.Tuple):
                v = self.visit(arg)
            else:
                raise RuntimeError("'{}' is not a valid argument.".format(ast27.dump(arg)))
            return ast3.arg(v, None, type_comment, lineno=arg.lineno, col_offset=arg.col_offset)

        def get_type_comment(i):
            # Indexes past the end of the list simply have no type comment.
            if i < len(n.type_comments) and n.type_comments[i] is not None:
                return n.type_comments[i]
            return None

        args = [convert_arg(arg, get_type_comment(i)) for i, arg in enumerate(n.args)]

        vararg = None
        if n.vararg is not None:
            vararg = ast3.arg(n.vararg,
                              None,
                              get_type_comment(len(args)),
                              lineno=-1, col_offset=-1)

        kwarg = None
        if n.kwarg is not None:
            # The kwarg's type comment follows the vararg's when there is one.
            kwarg = ast3.arg(n.kwarg,
                             None,
                             get_type_comment(len(args) + (0 if n.vararg is None else 1)),
                             lineno=-1, col_offset=-1)

        defaults = self.visit(n.defaults)

        return ast3.arguments(args,
                              vararg,
                              [],
                              [],
                              kwarg,
                              defaults)

    def visit_Str(self, s):
        """Convert a string literal to `Bytes` or `Str` based on its payload type."""
        if isinstance(s.s, bytes):
            return ast3.Bytes(s.s, s.kind)
        else:
            return ast3.Str(s.s, s.kind)

    def visit_Num(self, n):
        """Convert a number literal, splitting off a leading minus sign.

        A negative value is returned as ``UnaryOp(USub(), Num(-value))``; any
        other value is returned as a plain converted `Num`.
        """
        new = self.generic_visit(n)
        value = new.n
        if isinstance(value, complex):
            # Complex numbers have no ordering, so `value < 0` would raise
            # TypeError.  NOTE(review): assumes a complex `n` comes from an
            # imaginary literal (zero real part), so the sign of the
            # imaginary part decides -- confirm against the ast27 parser.
            is_negative = value.imag < 0
        else:
            is_negative = value < 0

        if is_negative:
            # Python 3 uses a unary - operator for negative literals.
            new.n = -value
            return ast3.UnaryOp(op=ast3.USub(),
                                operand=_copy_attributes(new, n))
        else:
            return new
