from loopinterpreter import interpret
import unittest
from unittest import mock


def input_exit(prompt=''):
    """Stub for the builtin ``input`` that always answers ``"EXIT"``.

    Used as the ``side_effect`` of the patched ``loopinterpreter.input`` in
    ``test_break_exit`` so no real user input is needed.

    Args:
        prompt: Ignored. Optional, mirroring the builtin ``input``, so the stub
            works whether or not the caller passes a prompt string.

    Returns:
        The string ``"EXIT"``.
    """
    return "EXIT"


def input_continue(prompt=''):
    """Stub for the builtin ``input`` that always answers with an empty line.

    Used as the ``side_effect`` of the patched ``loopinterpreter.input`` in
    ``test_break_continue`` so no real user input is needed.

    Args:
        prompt: Ignored. Optional, mirroring the builtin ``input``, so the stub
            works whether or not the caller passes a prompt string.

    Returns:
        The empty string ``""``.
    """
    return ""


def init_without_tokens(self, regex_to_token, program):
    """Replacement for ``Lexer.__init__`` that installs an empty token table.

    Patched over ``lexer.Lexer.__init__`` in ``test_unknown_tokens``. The
    ``regex_to_token`` argument supplied by the caller is discarded and an
    empty dict is stored instead, presumably leaving the lexer with no
    pattern any input could match.

    Args:
        self: The lexer instance being initialised.
        regex_to_token: Token table from the caller; intentionally unused.
        program: Source text, stored unchanged on the instance.
    """
    # Same attributes, same order, as the direct assignments they replace;
    # a fresh empty dict is created on every call.
    initial_state = {
        'regex_to_token': {},
        'program': program,
        'current_position': 0,
    }
    for attribute, value in initial_state.items():
        setattr(self, attribute, value)


class LOOPInterpreterTest(unittest.TestCase):
    """Unit tests for ``interpret``, the LOOP-language interpreter.

    The expectations below encode these semantics: ``interpret`` returns the
    final value of ``x0``; variables that were never assigned read as 0;
    subtraction is clamped at 0; the variable controlling a ``LOOP`` may not be
    used inside its body; and malformed programs raise ``SyntaxError``.
    """

    def _assert_syntax_errors(self, *programs):
        """Assert that interpreting each of *programs* raises ``SyntaxError``.

        Each program is checked in its own ``subTest``, so a failure names the
        offending source text and the remaining programs are still exercised.
        """
        for program in programs:
            with self.subTest(program=program), self.assertRaises(SyntaxError):
                interpret(program)

    def test_assignment_default_zero(self):
        """Variables that were never assigned read as 0."""
        self.assertEqual(0, interpret('x0:=x0 + 0'))
        self.assertEqual(0, interpret('x0:=x1 + 0'))
        self.assertEqual(0, interpret('x0:=x2 + 0'))

    def test_assignment_non_negative(self):
        """Subtraction never goes below 0."""
        self.assertEqual(0, interpret('x0:=x0-1'))
        self.assertEqual(0, interpret('x0:=x1-1'))
        self.assertEqual(0, interpret('x0:=x2-6'))

    def test_assignment_number(self):
        """A constant can be assigned directly."""
        self.assertEqual(5, interpret("x0:=5"))
        self.assertEqual(2, interpret("x0:=2"))
        self.assertEqual(3, interpret('x0:=3'))

    def test_assignment_variable(self):
        """``xi:=xj+c`` / ``xi:=xj-c`` assignments, alone and chained with ';'."""
        self.assertEqual(1, interpret('x0:=x0+1'))
        self.assertEqual(4, interpret('x0:= 5; x0:=x0-1'))
        self.assertEqual(1, interpret('x0:=x1-1; x0:=x0+1'))

    def test_assignment_wrong_syntax(self):
        """Assignments outside ``xi:=c``, ``xi:=xj+c``, ``xi:=xj-c`` are rejected."""
        self._assert_syntax_errors(
            'x1:=x2',
            'x1:=0+x2',
            'x5:=-1+x4',
            'x5:=-x3+x1',
            'x5:=x1-x3',
            'x2:=x1+x4',
            'x2:=x1+2;',
            'x1:=c',
            'xi:=2',
            'x0:=xj+1',
        )

    def test_loop_assignment(self):
        """``LOOP xi DO ... END`` runs its body xi times."""
        self.assertEqual(1, interpret('x1:=1; LOOP x1 DO x0:=1 END'))
        self.assertEqual(4, interpret('x1:=2; LOOP x1 DO x0:=x0 + 2 END'))

    def test_loop_empty_assignment(self):
        """A loop over an unassigned (hence 0) variable never runs its body."""
        self.assertEqual(0, interpret('LOOP x1 DO x0:=1 END'))
        self.assertEqual(0, interpret('x2:=2;LOOP x1 DO x0:=x2+1 END'))

    def test_loop_nested_assignment(self):
        """The inner loop's bound is re-read on every outer iteration."""
        self.assertEqual(6, interpret('x1:=3; LOOP x1 DO x2:=x2+1; LOOP x2 DO x0:=x0+1 END END'))
        self.assertEqual(3, interpret('x1:=3; x2:=3; LOOP x1 DO x2:=x2-1; LOOP x2 DO x0:=x0+1 END END'))

    def test_loop_forbidden_identifier(self):
        """Inside ``LOOP xi`` the loop variable xi may be neither written nor read."""
        self._assert_syntax_errors(
            'x1:=1; LOOP x1 DO x1:=x1+1 END',
            'x1:=1; LOOP x1 DO x2:=x1 + 2 END',
        )

    def test_loop_empty_forbidden_identifier(self):
        """The loop-variable restriction applies even if the body would run 0 times."""
        self._assert_syntax_errors(
            'LOOP x2 DO x2:=2 END',
            'LOOP x1 DO x2:=x1 - 2 END',
        )

    def test_loop_nested_forbidden_identifier(self):
        """Loop variables of every enclosing loop are off-limits in nested bodies."""
        # Every setup statement is terminated with ';' so that the only error
        # left in each program is the forbidden use of a loop variable, not a
        # missing statement separator.
        self._assert_syntax_errors(
            'x1:=2; LOOP x1 DO LOOP x1 DO x0:=x0+1 END END',
            'x1:=1; x2:=2; LOOP x1 DO LOOP x2 DO x1:=2 END END',
            'x1:=1; x2:=2; LOOP x1 DO LOOP x2 DO x2:=2 END END',
            'x1:=1; x2:=2; LOOP x1 DO LOOP x2 DO x0:=x2+2 END END',
            'x1:=1; x2:=2; LOOP x1 DO LOOP x2 DO x0:=x1-2 END END',
        )

    def test_loop_nested_empty_forbidden_identifier(self):
        """Nested loop-variable restrictions also hold for zero-iteration loops."""
        self._assert_syntax_errors(
            'LOOP x1 DO LOOP x2 DO x2:=2 END END',
            'LOOP x1 DO LOOP x2 DO x0:=x2+2 END END',
            'LOOP x1 DO LOOP x2 DO x0:=x1 + 0 END END',
            'LOOP x1 DO LOOP x2 DO x1:=2 END END',
            'LOOP x1 DO LOOP x1 DO x2:=2 END END',
        )

    def test_loop_wrong_syntax(self):
        """Constant loop bounds and stray ';' around loop bodies are rejected."""
        self._assert_syntax_errors(
            'LOOP 2 DO x2:=5 END',
            'x1:=1; LOOP x1 DO x2:=5; END',
            'x1:=1; LOOP x1 DO; x2:=5 END',
            'x1:=1; LOOP x1 DO x2:=5 END;',
        )

    def test_assignment_with_loop(self):
        """Statements after a loop see the values the loop produced."""
        self.assertEqual(2, interpret('x0:=2; LOOP x0 DO x1:=x1+1 END; x0:=x1+0'))
        self.assertEqual(1, interpret('x1:=x1+1; LOOP x0 DO x1:=x1+1 END; x0:=x1+0'))

    def test_syntax_unnecessary_semicolon(self):
        """Leading, trailing, or doubled ';' separators are rejected."""
        self._assert_syntax_errors(
            'LOOP x0 DO x1:=x1+1 END;',
            'LOOP x0 DO x1:=x1+1;; x1:=x1+1 END',
            'x1:=x1+1;; x1:=x1+1',
            ';x1:=x1+1',
        )

    def test_syntax_unnecessary_end(self):
        """An ``END`` without a matching ``LOOP`` is rejected."""
        self._assert_syntax_errors(
            'LOOP x0 DO x1:=x1+1 END END',
            'LOOP x0 DO x1:=x1+1 END; x1:=x1+1 END',
            'x1:=x1+1; END x1:=x1+1',
            'END x1:=x1+1',
        )

    def test_syntax_missing_semicolon(self):
        """Consecutive statements must be separated by ';'."""
        self._assert_syntax_errors(
            'x0:=2 LOOP x0 DO x1:=x1+1 END',
            'LOOP x0 DO x1:=x1+1 x1:=x1+1 END',
            'x0:=2; LOOP x0 DO x1:=x1+1 x1:=x1+1 END',
            'LOOP x0 DO x1:=x1+1 END x0:=x1+0',
            'x0:=2; LOOP x0 DO x1:=x1+1 END x0:=x1+0',
        )

    def test_syntax_missing_do(self):
        """A ``LOOP`` header must be followed by ``DO``."""
        self._assert_syntax_errors(
            'LOOP x1 x2:=2 END',
            'x1:=2; LOOP x1 x2:=2 END',
            'LOOP x0 DO LOOP x1 x2:=2 END END',
            'x0:=1; LOOP x0 DO LOOP x1 x2:=2 END END',
            'x0:=1; x1:=2; LOOP x0 DO LOOP x1 x2:=2 END END',
        )

    def test_syntax_missing_end(self):
        """Every ``LOOP`` must be closed by its own ``END``."""
        self._assert_syntax_errors(
            'LOOP x0 DO LOOP x1 DO x2:=2',
            'x0:=5; LOOP x0 DO LOOP x1 DO x2:=2',
            'x0:=4; x1:=7; LOOP x0 DO LOOP x1 DO x2:=2',
            'LOOP x0 DO LOOP x1 DO x2:=2 END',
            'x0:=2; LOOP x0 DO LOOP x1 DO x2:=2 END',
            'x0:=2; x1:=3; LOOP x0 DO LOOP x1 DO x2:=2 END',
            'x0 := 2; LOOP x0 DO x1 := 1; x2 := x2 + 1',
            'LOOP x0 DO x1:=2; x2:=0',
        )

    def test_syntax_missing_program(self):
        """A loop body (or the statement after a ';') may not be empty."""
        self._assert_syntax_errors(
            'LOOP x0 DO END',
            'x0:=2; LOOP x0 DO END',
            'LOOP x0 DO LOOP x1 DO END END',
            'LOOP x0 DO LOOP x1 DO x2:=2; END',
        )

    def test_syntax_missing_operator(self):
        """A variable and a constant on the right-hand side need a '+' or '-'."""
        self._assert_syntax_errors(
            'x0:=x1 2',
            'LOOP x2 DO x0:=x1 2 END',
            'x2:=3; LOOP x2 DO x1:=x1 2 END',
            'LOOP x2 DO LOOP x3 DO x1:=x1 2 END END',
        )

    def test_syntax_missing_equals(self):
        """An assignment without ':=' is rejected."""
        self._assert_syntax_errors(
            'x1 2',
            'x1 x2+2',
            'LOOP x0 DO x1 2 END',
            'x0:=2; LOOP x0 DO x1 2 END',
            'LOOP x0 DO x1 x2+3 END',
            'x0:=2; LOOP x0 DO x1 x2-1 END',
        )

    def test_syntax_missing_identifier(self):
        """Missing operands or loop variables are rejected."""
        self._assert_syntax_errors(
            "x1:=; LOOP x1 DO x2:=2 END",
            "LOOP x1 DO x2:= END",
            "LOOP x1 DO x2:=x0+ END",
            'LOOP x0 DO LOOP DO x1:=x2+0 END END',
        )

    def test_newlines(self):
        """Newlines and indentation between statements are accepted as whitespace."""
        self.assertEqual(5, interpret('''x2:=3;
        x0:=x2+2'''))
        self.assertEqual(2, interpret('x1:=x1-2;\n x0:=x1+2'))

    @mock.patch('loopinterpreter.input', side_effect=input_exit)
    def test_break_exit(self, custom_input):
        """BREAK with the patched ``input`` answering "EXIT" makes ``interpret`` return -1."""
        self.assertEqual(-1, interpret('x1:=2; BREAK x0:=2'))
        # NOTE(review): x1 is never assigned here, so (per
        # test_loop_empty_assignment) the loop body runs zero times. This
        # expectation therefore only holds if BREAK is acted on while parsing
        # rather than while executing -- confirm that is intended.
        self.assertEqual(-1, interpret('LOOP x1 DO BREAK x2:= 2 END'))

    @mock.patch('loopinterpreter.input', side_effect=input_continue)
    def test_break_continue(self, custom_input):
        """With the patched ``input`` answering "", execution carries on past BREAK."""
        self.assertEqual(4, interpret('x1:=2; LOOP x1 DO x0:=x0+2 BREAK END'))

    @mock.patch('lexer.Lexer.__init__', init_without_tokens)
    def test_unknown_tokens(self):
        """With the lexer's token table emptied, input cannot be tokenized."""
        self._assert_syntax_errors('BLIBLABLUB')