"""
Helper functions for database migration scripts.

This module contains helper functions for writing custom database migration
scripts, which are typically needed whenever a database schema changes.  The
helpers assume a PostgreSQL server and its command-line tools.

Note: when copying data over from another database, remember to copy the
sequence values too; otherwise newly inserted rows may reuse ids that were
copied over.

TODO: We need to add functions to copy the values of sequences.
"""

__author__ = 'Martin Blais <blais@furius.ca>'


# stdlib imports
import re
from os.path import *
from itertools import izip
from subprocess import call, Popen, PIPE

# dbapi imports
import psycopg2 as dbapi


def ditch_database(dbname, user):
    """
    Drop the database 'dbname' and create it again, empty, as user 'user'.

    The exit status of 'dropdb' is not checked (presumably so that a database
    that does not exist yet is not an error); if 'createdb' fails, a
    RuntimeError is raised.
    """
    args = ['-U', user, dbname]
    call(['dropdb'] + args)
    status = call(['createdb'] + args)
    if status:
        raise RuntimeError("Could not create database '%s'" % dbname)


def dump_database(dbname, user):
    """
    Start 'pg_dump' on database 'dbname' in custom ('-F c') archive format and
    return the readable pipe attached to its standard output.

    Only the pipe is returned, not the process handle, so the caller cannot
    wait on pg_dump or inspect its exit status.
    """
    command = ('pg_dump', '-i', '-U', user, '-F', 'c', '-v', '-b', '-c', '-C',
               dbname)
    return Popen(command, stdout=PIPE).stdout

    
def dump_schema(dbname, user):
    """
    Start a schema-only 'pg_dump' of the 'public' schema of database 'dbname'.

    Returns a (process, stdout) pair: the Popen object, so that the caller can
    wait on it, and its readable standard output, which yields the dump text.
    """
    command = ['pg_dump', '-U', user, '--schema-only', '--schema', 'public',
               dbname]
    proc = Popen(command, stdout=PIPE)
    return proc, proc.stdout

def copy_table(curs_from, curs_to, table, desttable=None,
               rename=None, ignore=None):
    """
    Naively copy a table's contents to another, one INSERT per row.

    'curs_from', 'curs_to' -> dbapi cursor: cursors on the source and the
        destination databases.  'curs_to' must accept '%s' placeholders
        (e.g. psycopg2).
    'table' -> str: the source table name.
    'desttable' -> str: the destination table name; defaults to 'table'.
    'rename' -> dict of str to str: maps source column names to the column
        names used in the destination table.
    'ignore' -> collection of str: source columns that are not copied.

    Table and column names are interpolated directly into the SQL text, so
    they must come from trusted code, never from user input.
    """
    # If no destination table is specified, use the same name as the source.
    if desttable is None:
        desttable = table

    # Select all the columns to discover their names.
    curs_from.execute("""
        SELECT * FROM %s LIMIT 1
        """ % table)
    columns = [x[0] for x in curs_from.description]

    # Drop the ignored columns.
    if ignore is not None:
        columns = [col for col in columns if col not in ignore]

    # Redo the select with just the columns we want.
    curs_from.execute("""
        SELECT %s FROM %s
        """ % (','.join(columns), table))
    columns = [x[0] for x in curs_from.description]

    # Map source column names to their destination names.
    if rename is not None:
        columns = [rename.get(col, col) for col in columns]

    # Build the INSERT once; it targets 'desttable', not the source 'table'.
    phold = ','.join(['%s'] * len(columns))
    insert_sql = """
            INSERT INTO %s (%s) VALUES (%s)
            """ % (desttable, ','.join(columns), phold)

    # Copy the data.
    for row in curs_from:
        curs_to.execute(insert_sql, row)

def compare_table_data(curs_from, curs_to, table, desttable=None,
                       ignore=None, order_by='id'):
    """
    Compare all the data found in two tables.  Raises an exception if the
    tables differ.

    'curs_from', 'curs_to' -> dbapi cursor: cursors on the source and the
        destination databases.
    'table' -> str: the source table name.
    'desttable' -> str: the destination table name; defaults to 'table'.
    'ignore' -> collection of str: columns excluded from the comparison.
    'order_by' -> str: SQL ORDER BY expression used to line up the rows of
        both tables (defaults to 'id').

    Raises RuntimeError if the column names, the row counts, or any pair of
    rows differ.
    """
    from pprint import pformat

    sides = [(curs_from, table), (curs_to, desttable or table)]

    # Collect the sorted column names of both tables, minus ignored ones.
    columns = []
    for curs, name in sides:
        curs.execute(""" SELECT * FROM %s LIMIT 1 """ % name)
        cols = sorted(x[0] for x in curs.description)
        if ignore is not None:
            cols = [col for col in cols if col not in ignore]
        columns.append(cols)
    cols_from, cols_to = columns

    if cols_from != cols_to:
        # Show both full lists; zip() would hide extra trailing columns.
        raise RuntimeError("Columns differ:\n%s\n" %
                           pformat((cols_from, cols_to)))

    # Fetch both tables with the same column and row order.
    for curs, name in sides:
        curs.execute(""" SELECT %s FROM %s ORDER BY %s""" %
                     (','.join(cols_from), name, order_by))

    if curs_from.rowcount != curs_to.rowcount:
        raise RuntimeError("Data length differs:\n%d\n%d\n" %
                           (curs_from.rowcount, curs_to.rowcount))

    # Compare row by row.  Drivers that report rowcount -1 for SELECT slip
    # through the check above, so a length mismatch is also caught here.
    rows_to = iter(curs_to)
    for row1 in curs_from:
        row2 = next(rows_to, None)
        if row1 != row2:
            raise RuntimeError("Data differs:\n%s\n%s\n" % (row1, row2))
    extra = next(rows_to, None)
    if extra is not None:
        raise RuntimeError("Data differs:\n%s\n%s\n" % (None, extra))

def parse_options():
    """
    Parse the command line of a migration script.

    Expects three positional arguments: the original database, the destination
    database, and a '<type>:<name>' new-schema source where <type> is 'file'
    or 'db'.  An optional -D/--recreate-original-db option names a dump file,
    which must exist.  Returns (opts, dbfrom, dbto); 'opts.newtyp' and
    'opts.newschema' hold the two halves of the new-schema source.  Usage
    errors are reported through optparse's parser.error().
    """
    import optparse

    usage = migration_init.__doc__.strip()
    parser = optparse.OptionParser(usage)
    parser.add_option('-D', '--recreate-original-db', action='store',
                      help=("Drop server DB and recreate using the given "
                            "dump file."))
    opts, args = parser.parse_args()

    if len(args) != 3:
        parser.error("You need to specify an original database name, "
                     "a destination/migrated database name, and a source "
                     "for the new schema to initialize the destination "
                     "database.")
    dbfrom, dbto, newschema = args

    # Split '<type>:<name>' into its two parts.
    match = re.match('(file|db):(.+)$', newschema)
    if match is None:
        parser.error("Error: New schema name is invalid.")
    opts.newtyp, opts.newschema = match.groups()

    dumpfn = opts.recreate_original_db
    if dumpfn and not exists(dumpfn):
        parser.error("Error: server dump file does not exist.")

    return opts, dbfrom, dbto

def recreate_database_from_dump(dbname, user, dumpfn):
    """
    Recreate a database from a dump file.

    Drops and recreates database 'dbname' with ditch_database(), then restores
    the archive 'dumpfn' into it with pg_restore.  'dumpfn' is presumably a
    custom-format archive such as dump_database() emits.  Both steps run as
    database user 'user'.

    Returns pg_restore's exit status; it is not checked here.
    """
    ditch_database(dbname, user)
    # Pass -U like ditch_database() does.  Without it, pg_restore would
    # connect as its default user (PGUSER or the OS user name) instead.
    return call(['pg_restore', '-U', user, '-d', dbname, dumpfn])


def migration_init(user, host='localhost'):
    """
    Migrate my database contents to the new database format.  The destination
    database is always dropped before running this script.  It is initialized
    either from a file or from dumping another (template) database schema.

      <this-script> [<options>] <original-db> <destination-db> <new-schema>

    Where <new-schema> can be one of

      file:<filename>     : a SQL dump of the new schema to use
      db:<database-name>  : the name of a template database whose schema
                            we should use (dump it dynamically)

    This function then opens the original and the destination databases and
    returns a triple of the dbapi module, a connection on the original DB, and a
    connection on the destination DB.

    Important note: you should NOT commit changes that you make to the original
    database.  This will allow you to test your migration procedures multiple
    times without having to recreate the original database from a dump, and
    allows you to test the migration procedure on the actual server (although,
    really, I would not do that without performing a backup beforehand).
    """
    # NOTE: the docstring above doubles as the command-line usage text printed
    # by parse_options(); edit it with care.
    # NOTE(review): 'host' is only used for the dbapi connections; the
    # command-line tools (dropdb, createdb, pg_dump, psql) use their defaults.
    opts, dbfrom, dbto = parse_options()

    # Optionally recreate and restore the source database.
    if opts.recreate_original_db:
        # Recreate the server database
        print ("--- Recreating the current/server/production database '%s'" %
               dbfrom)
        recreate_database_from_dump(dbfrom, user, opts.recreate_original_db)

    # Connect to the source database.
    conn_from = dbapi.connect(host=host,
                              database=dbfrom,
                              user=user)

    # Initialize the destination database
    print ("--- Initializing new/destination/converted database '%s'" %
           dbto)
    ditch_database(dbto, user)

    # Open the new schema source.  It is either a SQL file, or the stdout of a
    # schema-only pg_dump of the template database.
    dumper = None
    if opts.newtyp == 'file':
        print ("--- Reading new schema from file")
        try:
            schef = open(opts.newschema, 'r')
        except IOError:
            raise SystemExit("Error: could not open schema file '%s'" %
                             opts.newschema)
    else:
        # Copy from the template database.
        print ("--- Finding new schema by dumping template database '%s'" %
               opts.newschema)
        dumper, schef = dump_schema(opts.newschema, user)
    print ("--- Initializing destination database '%s' with schema from '%s'" %
           (dbto, opts.newschema))

    # Feed the schema to psql and discard its output.  communicate() drains
    # the pipe.  subprocess.call(..., stdout=PIPE) never reads it, and would
    # deadlock once psql's output filled the OS pipe buffer.
    try:
        psql = Popen(['psql', '-U', user, '-d', dbto], stdin=schef,
                     stdout=PIPE)
        psql.communicate()
    finally:
        # Release our handle on the schema source.  If it came from pg_dump,
        # also reap that child process.
        schef.close()
        if dumper is not None:
            dumper.wait()

    # Connect to destination database.
    conn_to = dbapi.connect(host=host,
                            database=dbto,
                            user=user)

    # Return the DBAPI and connections to the caller.
    return dbapi, conn_from, conn_to


if __name__ == '__main__':
    # Manual smoke test: run the migration setup as database user 'carpool'
    # with the database connections going to localhost.
    migration_init(user='carpool', host='localhost')


