#!/usr/bin/python3

import sys
import os
import time

pg_hba_conf = '/var/lib/pgsql/data/pg_hba.conf'
# Backup path: the live path plus a zero-padded local timestamp suffix
# (year-month-day-hour-minute-second).
pg_hba_bak = pg_hba_conf + '-' + '-'.join(str(el).zfill(2) for el in time.localtime()[:6])

# - "local" is a Unix-domain socket
# - "host" is a TCP/IP socket (encrypted or not)
# - "hostssl" is a TCP/IP socket that is SSL-encrypted
# - "hostnossl" is a TCP/IP socket that is not SSL-encrypted
# - "hostgssenc" is a TCP/IP socket that is GSSAPI-encrypted
# - "hostnogssenc" is a TCP/IP socket that is not GSSAPI-encrypted
typeorder = ['local', 'host', 'hostssl', 'hostnossl', 'hostgssenc', 'hostnogssenc']


def parse_entries(lines):
    """Group pg_hba.conf records by database and then by connection type.

    Blank lines and lines starting with '#' are skipped.  Every other line
    is split on spaces and tabs; field 0 is the connection type, field 1 the
    database and field 2 the user.  Quoted values containing whitespace are
    not understood and would be split like any other text.

    Returns a tuple ``(dbconf, dborder)``:
      * ``dbconf[database][conn_type]`` is the list of records (each a list
        of fields) in file order;
      * ``dborder`` lists databases in output order: every database other
        than 'replication' and 'all' (most recently first-seen one first),
        then 'replication', then 'all'.

    Raises ValueError for a record with fewer than three fields or with a
    connection type that is not in ``typeorder`` (for example an ``include``
    directive).  Such records could not be written back, so the caller must
    not rewrite the file.
    """
    dbconf = {}
    dborder = ['replication', 'all']
    for lineno, raw in enumerate(lines, 1):
        line = raw.strip()
        if not line or line.startswith('#'):
            continue
        vals = [tok for tok in line.replace("\t", " ").split(' ') if tok]
        if len(vals) < 3:
            raise ValueError(
                "pg_hba.conf line %d: expected at least 3 fields: %s" % (lineno, line))
        conn_type, db = vals[0], vals[1]
        if conn_type not in typeorder:
            raise ValueError(
                "pg_hba.conf line %d: unsupported record type '%s'" % (lineno, conn_type))
        if db not in dbconf:
            dbconf[db] = {}
            if db not in dborder:
                # Specific databases must come before 'replication' and 'all'.
                dborder.insert(0, db)
        dbconf[db].setdefault(conn_type, []).append(vals)
    return dbconf, dborder


def render_entries(dbconf, dborder):
    """Return the reordered records as a list of tab-joined output lines.

    Records are emitted database by database (in ``dborder`` order) and,
    within a database, connection type by connection type (in ``typeorder``
    order).  Inside one database/type group, records whose user is not 'all'
    keep their file order and come first; user-'all' records follow.
    """
    out = []
    for db in dborder:
        by_type = dbconf.get(db, {})
        for conn_type in typeorder:
            catch_all = []
            for entry in by_type.get(conn_type, []):
                if entry[2] != 'all':  # if user is all, it should be always the last
                    out.append("\t".join(entry) + "\n")
                else:
                    # NOTE(review): insert(0, ...) emits the user-'all'
                    # records in reverse file order; kept as-is because it
                    # may be intentional -- confirm.
                    catch_all.insert(0, entry)
            for entry in catch_all:
                out.append("\t".join(entry) + "\n")
    return out


def main():
    """Rewrite pg_hba.conf in sorted order, keeping the original as a backup.

    The file is parsed and validated before anything on disk is touched, so
    a malformed file is reported and left in place.  The original is then
    renamed to the timestamped backup path and the reordered records are
    written to a new file that receives the original's permission bits and
    ownership.  If writing fails, the backup is renamed back.
    """
    if not os.path.exists(pg_hba_conf):
        sys.stderr.write("pg_hba.conf does not exist\n")
        sys.exit(1)
    if os.path.exists(pg_hba_bak):
        sys.stderr.write("pg_hba.conf backup file already exists\n")
        sys.exit(1)

    with open(pg_hba_conf) as source:
        lines = source.readlines()
    try:
        dbconf, dborder = parse_entries(lines)
    except ValueError as err:
        sys.stderr.write("%s\n" % err)
        sys.exit(1)
    output = render_entries(dbconf, dborder)

    original = os.stat(pg_hba_conf)
    os.rename(pg_hba_conf, pg_hba_bak)
    try:
        with open(pg_hba_conf, "w") as pghba:
            # A freshly created file gets the caller's umask and ownership;
            # copy both from the original so the server can still read it.
            os.fchmod(pghba.fileno(), original.st_mode & 0o7777)
            try:
                os.fchown(pghba.fileno(), original.st_uid, original.st_gid)
            except PermissionError:
                sys.stderr.write("warning: could not restore ownership of pg_hba.conf\n")
            pghba.writelines(output)
    except BaseException:
        # Never leave a missing or half-written config behind.
        os.rename(pg_hba_bak, pg_hba_conf)
        raise


if __name__ == "__main__":
    main()

