#!/usr/bin/python
# -*- coding: utf-8 -*-

import argparse
import glob
import os
import sys

from knapsack.parse_file import PACKAGE


def join(path, filename):
    """Join ``filename`` below ``path``, treating ``filename`` as relative.

    ``os.path.join`` discards ``path`` when the second argument is absolute,
    so any leading slashes are stripped from ``filename`` first.  An empty
    ``filename`` (or a bare ``'/'``) yields ``path`` with a trailing separator.
    """
    # lstrip() rather than indexing filename[0]: it is safe on the empty
    # string and also neutralises names such as '//x'.
    return os.path.join(path, filename.lstrip('/'))


def get_files(path):
    """Return the files and directories found below ``path``.

    The tree is walked with ``os.walk`` and two sets are returned,
    ``(files, dirs)``.  Every entry is expressed relative to ``path`` but
    anchored at ``'/'`` (``'/sub/file'``, ``'/sub'``); ``path`` itself is
    reported as the directory ``'/'``, which is always present in ``dirs``
    even when ``path`` cannot be walked.
    """
    files, dirs = set(), set(['/'])
    for dirpath, _dirnames, filenames in os.walk(path):
        # relpath() copes with a trailing separator on ``path``, which plain
        # slicing by len(path) does not.
        rel = os.path.relpath(dirpath, path)
        # Anchor at '/' so files directly inside ``path`` are reported as
        # '/name', consistent with entries from sub-directories.
        rel = '/' if rel == os.curdir else '/' + rel
        dirs.add(rel)
        files.update(os.path.join(rel, f) for f in filenames)
    return files, dirs


# Single-entry cache shared between calls to find_packages():
#   _CACHE[0] -- the directory that was listed last
#   _CACHE[1] -- the sorted entries of that directory not yet consumed
_CACHE = [None, None]


def find_packages(path, filename_prefix):
    """Yield the entries below ``path`` that match ``filename_prefix``.

    ``filename_prefix`` is joined onto ``path`` and split into a directory
    and a base-name prefix.  The directory is listed (sorted) once and the
    listing is kept in ``_CACHE``; entries are removed from the front of that
    listing as they are examined, so successive calls for the same directory
    continue where the previous one stopped.  This assumes the prefixes are
    sorted when called.

    Each yielded value is the matching entry's full path with the first
    ``len(path)`` characters removed.  A directory that cannot be listed is
    treated as empty.  This is a generator: nothing happens (including the
    cache update) until it is iterated.
    """
    _filename = join(path, filename_prefix)
    _dirname, _prefixname = os.path.split(_filename)
    if _CACHE[0] == _dirname:
        all_files = _CACHE[1]
    else:
        _CACHE[0] = _dirname
        try:
            all_files = sorted(os.listdir(_dirname))
        except OSError:
            all_files = []
        _CACHE[1] = all_files
    # Entries popped during this call, kept so the listing can be restored
    # if the prefix turns out not to match anything.
    donewith = []
    found = False
    while all_files:
        _basename = all_files.pop(0)
        donewith.append(_basename)
        if not _basename.startswith(_prefixname):
            # we can skip everything if we already had a prefix hit and now
            # are no longer matching
            if found:
                # but be careful - it might be the next prefix
                all_files.insert(0, _basename)
                return
            continue
        found = True
        # NOTE(review): PACKAGE comes from knapsack.parse_file; its first
        # group is presumably the package name without the version/extension
        # part -- confirm against that module.
        m = PACKAGE.match(_basename)
        _path = os.path.join(_dirname, m.groups()[0]) if m else _filename
        if _path.endswith(filename_prefix):
            yield os.path.join(_dirname, _basename)[len(path):]
        else:
            # if it's not matching, the path is longer so we can stop here
            all_files.insert(0, _basename)
            return
    if not found:
        # The whole listing was consumed without a hit: put it back so the
        # next prefix for this directory can still be searched.
        _CACHE[1] = donewith


if __name__ == '__main__':
    # Make --dst contain exactly the files selected by the knapsack solution
    # file: missing files are hard-linked from --src (falling back to
    # --srcalt), and files/directories no longer selected are removed.
    parser = argparse.ArgumentParser(description='Synchronize a rsync directory using hard links and an external file.')
    parser.add_argument('--src', help='Src directory')
    parser.add_argument('--srcalt', help='Alternative src directory, tried when linking from --src fails')
    parser.add_argument('--dst', help='Dst directory')
    parser.add_argument('kpsol', help='Knapsack solution file')

    args = parser.parse_args()

    if not all((args.src, args.dst, args.kpsol)):
        parser.print_help()
        sys.exit(0)

    args.src = os.path.abspath(args.src)
    # --srcalt is optional; abspath(None) would raise.
    args.srcalt = os.path.abspath(args.srcalt) if args.srcalt else None
    args.dst = os.path.abspath(args.dst)

    # find_packages() requires its prefixes in sorted order.
    with open(args.kpsol) as kpsol:
        src_files_to_expand = sorted(line.strip() for line in kpsol if line.strip())
    src_files = set()
    for src_file in src_files_to_expand:
        src_files.update(find_packages(args.src, src_file))

    # Every ancestor directory of a selected file, relative to --src.
    src_dir = set(['/'])
    for src_file in src_files:
        d = os.path.dirname(src_file)
        # '' guards against looping forever on a path without a leading '/'.
        while d not in ('/', ''):
            src_dir.add(d)
            d = os.path.dirname(d)

    dst_files, dst_dir = get_files(args.dst)

    new_files = src_files - dst_files
    del_files = dst_files - src_files

    new_dir = src_dir - dst_dir
    cur_dir = src_dir & dst_dir
    del_dir = dst_dir - src_dir

    if new_dir:
        sys.stderr.write('Creating new directories ({0})...\n'.format(len(new_dir)))
    # Sorted so that parents are created before their children.
    for d in sorted(new_dir):
        # Take the mode from the directory under --src, not from the
        # same-named path at the filesystem root.
        src_stat = os.stat(join(args.src, d))
        d = join(args.dst, d)
        os.mkdir(d)
        os.chmod(d, src_stat.st_mode)

    # Update permissions
    for d in sorted(cur_dir):
        src_stat = os.stat(join(args.src, d))
        d = join(args.dst, d)
        os.chmod(d, src_stat.st_mode)

    if new_files:
        sys.stderr.write('Linking new files ({0})...\n'.format(len(new_files)))
    for f in new_files:
        src, dst = join(args.src, f), join(args.dst, f)
        try:
            os.link(src, dst)
        except OSError:
            sys.stdout.write('SRC {0}\n'.format(src))
            sys.stdout.write('DST {0}\n'.format(dst))
            if args.srcalt is None:
                # No alternative source to try: report the original failure.
                raise
            os.link(join(args.srcalt, f), dst)

    if del_files:
        sys.stderr.write('Unlinking deleted files ({0})...\n'.format(len(del_files)))
    for f in del_files:
        os.unlink(join(args.dst, f))

    if del_dir:
        sys.stderr.write('Unlinking deleted directories ({0})...\n'.format(len(del_dir)))
    # Reverse order so that children are removed before their parents.
    for d in sorted(del_dir, reverse=True):
        os.rmdir(join(args.dst, d))
