import argparse
import os
import sys
import unittest
import urllib.parse

import sqlalchemy

class TestQnaCorrectness(unittest.TestCase):
    """Data-integrity checks for ``qna_table`` in the annotator database.

    Every check is a SELECT that should match no rows; any returned row is a
    record violating the invariant. Connection settings are read from the
    module-level ``args_input`` namespace, which is only created when this
    file is run as a script (see the ``__main__`` block).
    """

    # state_id values used by the checks (names taken from the per-query
    # comments that accompanied the original literal id lists).
    STATE_NONE = 1
    STATE_QUEUED = 2
    STATE_PENDING = 3
    STATE_ANSWERED = 4
    STATE_CONFIRMED = 6
    STATE_GOLDEN = 7

    @classmethod
    def setUpClass(cls):
        """Create one engine (and connection pool) shared by all tests."""
        # 'postgresql://' is the dialect name SQLAlchemy recognises; the
        # 'postgres://' alias was removed in SQLAlchemy 1.4. The password is
        # percent-encoded so characters such as '@', ':' or '/' cannot
        # corrupt the URL.
        db_string = 'postgresql://qna:{}@{}/qna'.format(
            urllib.parse.quote_plus(args_input.annotator_db_password),
            args_input.annotator_db_host)
        cls.db_engine = sqlalchemy.create_engine(db_string)

    @classmethod
    def tearDownClass(cls):
        """Release the pooled connections opened by ``setUpClass``."""
        cls.db_engine.dispose()

    def _assert_no_rows(self, conn, sql, **params):
        """Fail if the SELECT ``sql`` (run with bind ``params``) returns a row.

        ``rowcount`` is only reliably defined for UPDATE/DELETE statements,
        so existence is tested by fetching a single row instead.
        """
        row = conn.execute(sqlalchemy.text(sql), params).fetchone()
        self.assertIsNone(
            row, 'unexpected row {!r} returned by: {}'.format(row, sql))

    def _check_task(self, conn, task_id, exempt_states=()):
        """Run the invariants shared by every per-task test.

        For non-deleted rows of ``task_id``, each tracked column must be
        non-NULL unless the row is in one of the states listed for that
        column. ``exempt_states`` are extra state_ids skipped by all three
        column checks. Finally, every row of the task must have
        ``created_at`` set.
        """
        skipped = (self.STATE_NONE, self.STATE_QUEUED) + tuple(exempt_states)
        column_checks = (
            # Rater must be assigned once the row leaves NONE/QUEUED.
            ('last_assigned_rater', skipped),
            # Answer timestamp is additionally not required while PENDING.
            ('last_answered_ds', skipped + (self.STATE_PENDING,)),
            # Manager is additionally not required while ANSWERED or GOLDEN.
            ('last_assigned_manager',
             skipped + (self.STATE_PENDING, self.STATE_ANSWERED,
                        self.STATE_GOLDEN)),
        )
        for column, states in column_checks:
            # Column names are fixed literals and states are coerced to int,
            # so formatting them into the SQL is safe; task_id is bound.
            self._assert_no_rows(
                conn,
                'SELECT id FROM qna_table WHERE task_id = :task_id '
                'AND state_id NOT IN ({}) AND deleted = false '
                'AND {} IS NULL LIMIT 1'.format(
                    ', '.join(str(int(s)) for s in states), column),
                task_id=task_id)
        self._assert_no_rows(
            conn,
            'SELECT id FROM qna_table WHERE task_id = :task_id '
            'AND created_at IS NULL LIMIT 1',
            task_id=task_id)

    def test_instance_segmentation_data(self):
        """Common invariants for task 63."""
        with self.db_engine.connect() as conn:
            self._check_task(conn, 63)

    def test_classification_data(self):
        """Common invariants for task 93."""
        with self.db_engine.connect() as conn:
            self._check_task(conn, 93)

    def test_track_data(self):
        """Common invariants for task 41."""
        with self.db_engine.connect() as conn:
            self._check_task(conn, 41)

    def test_traffic_light_data(self):
        """Common invariants for task 180."""
        with self.db_engine.connect() as conn:
            self._check_task(conn, 180)

    def test_detector_data(self):
        """Invariants for task 44.

        STATE_CONFIRMED rows are exempt from the common column checks and
        instead get a date-restricted check covering post-migration data.
        """
        with self.db_engine.connect() as conn:
            self._check_task(conn, 44, exempt_states=(self.STATE_CONFIRMED,))
            # Confirmed rows created after 2019/09/18 (data after migration)
            # must have all three tracked columns set. The /1e9 suggests
            # created_at is stored as epoch nanoseconds -- TODO confirm.
            self._assert_no_rows(
                conn,
                "SELECT id FROM qna_table WHERE task_id = :task_id "
                "AND state_id = :confirmed AND deleted = false "
                "AND (last_assigned_manager IS NULL "
                "OR last_assigned_rater IS NULL "
                "OR last_answered_ds IS NULL) "
                "AND to_char(to_timestamp((created_at)/1e9)::date, "
                "'YYYY/MM/DD') > '2019/09/18' LIMIT 1",
                task_id=44, confirmed=self.STATE_CONFIRMED)

    def test_priority(self):
        """No row of any task may have a NULL priority."""
        with self.db_engine.connect() as conn:
            self._assert_no_rows(
                conn, 'SELECT id FROM qna_table WHERE priority IS NULL LIMIT 1')

if __name__ == '__main__':
    # Script-only options are parsed here; anything argparse does not
    # recognise (e.g. -v, test names) is forwarded to unittest.main().
    parser = argparse.ArgumentParser()
    parser.add_argument('--annotator_db_host',
            default='annotator-db-v1.cp4rxnpmuhoe.us-west-1.rds.amazonaws.com:5432',
            help='host[:port] of the annotator PostgreSQL database')
    # Allow the password to come from the environment so it need not appear
    # on the command line (visible in `ps` output and shell history).
    env_password = os.environ.get('ANNOTATOR_DB_PASSWORD')
    parser.add_argument('--annotator_db_password',
            default=env_password, required=env_password is None,
            help='password for the qna user '
                 '(defaults to $ANNOTATOR_DB_PASSWORD)')
    # Module-level global read by TestQnaCorrectness.setUpClass.
    args_input, unknown = parser.parse_known_args()
    unittest.main(argv=[sys.argv[0]] + unknown)
