#!/usr/bin/python
# vim: tabstop=4 expandtab shiftwidth=4 softtabstop=4

import ldap
import ldif
import ldap.modlist
import ConfigParser
import os
import sys

import dateutil.parser
import dateutil.tz
import datetime

def notify_created(dn):
    """Report on stdout that the destination entry *dn* was created."""
    sys.stdout.write("notify_created %s\n" % (dn,))

def notify_modified(dn):
    """Report on stdout that the destination entry *dn* was modified."""
    sys.stdout.write("notify_modified %s\n" % (dn,))

def notify_deleted(dn):
    """Report on stdout that the destination entry *dn* was deleted."""
    sys.stdout.write("notify_deleted %s\n" % (dn,))

def readLDIFSource(path):
    """Parse the LDIF file at *path* and return the parsed records.

    The result is the ``all_records`` list of an ``ldif.LDIFRecordList``;
    callers treat each record as a ``(dn, attributes)`` pair.
    """
    ldif_file = open(path, 'r')
    try:
        reader = ldif.LDIFRecordList(ldif_file)
        reader.parse()
    finally:
        # Close the file even if parsing raises.
        ldif_file.close()
    return reader.all_records

def readLdapSource(server, binddn, bindpw, basedn, filter, starttls=False, port=389):
    """Run a subtree search on a source LDAP server and return the results.

    Connects to *server*:*port*, optionally upgrades the connection with
    StartTLS, binds as *binddn*/*bindpw* and searches *basedn* (whole
    subtree) with *filter*, requesting all attributes (``attrlist=None``).

    :returns: the list returned by ``search_s`` (``(dn, attrs)`` tuples).
    The connection is always unbound before returning, also on errors.
    """
    con = ldap.open(server, port=port)
    try:
        if starttls:
            con.start_tls_s()
        con.simple_bind_s(binddn, bindpw)
        return con.search_s(basedn, ldap.SCOPE_SUBTREE, filter, None)
    finally:
        # Previously the connection was never closed (leak on every call).
        con.unbind()

def _rebase_dn(dn, srcbasedn, destbasedn):
    """Translate *dn* from the source tree into the destination tree.

    The part of *dn* in front of *srcbasedn* is kept and *destbasedn* is
    appended. The suffix match is case-insensitive. Returns None when *dn*
    is not located under *srcbasedn*.
    """
    if not dn.lower().endswith(srcbasedn.lower()):
        return None
    # Explicit end index: dn[:-len(srcbasedn)] would yield "" for an empty base.
    rpath = dn[:len(dn) - len(srcbasedn)]
    if srcbasedn:
        # The relative part must stop at an RDN separator ("cn=x," + base);
        # otherwise srcbasedn merely matched the tail of a longer RDN value.
        if rpath and not rpath.rstrip().endswith(","):
            return None
    elif rpath and destbasedn:
        # Empty source base: the whole DN is relative and needs a separator.
        rpath += ","
    return rpath + destbasedn


def _map_objectclasses(objectclasses, classmap):
    """Rename object classes via *classmap* (keyed by lowercase class name).

    Classes without a mapping are kept unchanged; duplicates produced by the
    mapping are dropped while preserving first-seen order.
    """
    mapped = []
    for oc in objectclasses:
        new_oc = classmap.get(oc.lower(), oc)
        if new_oc not in mapped:
            mapped.append(new_oc)
    return mapped


def _prepare_entries(searchresult, srcbasedn, destbasedn, attrmap, classmap):
    """Convert source ``(dn, attrs)`` records into ``(dest dn, cidict)`` pairs.

    Applies the object class and attribute renames and moves each DN below
    *destbasedn*. Records whose DN is not under *srcbasedn* are reported and
    skipped.
    """
    update_objects = []
    for srcdn, attrs in searchresult:
        entry = ldap.cidict.cidict(attrs)
        if 'objectclass' in entry:
            entry['objectclass'] = _map_objectclasses(entry['objectclass'], classmap)
        # Iterate over a snapshot: renames mutate the dict.
        for attr in list(entry.keys()):
            if attr.lower() not in attrmap:
                continue
            new_attr = attrmap[attr].lower()
            if new_attr != attr.lower():
                values = entry[attr]
                del entry[attr]
                entry[new_attr] = values
        dn = _rebase_dn(srcdn, srcbasedn, destbasedn)
        if dn is None:
            print("skipping %s: not below %s" % (srcdn, srcbasedn))
            continue
        update_objects.append((dn, entry))
    return update_objects


def _sync_password_lock(con, dn, entry, now, max_age):
    """Apply the password-policy hack for one existing destination entry.

    Reads pwdChangedTime from the source *entry*. If the password is older
    than *max_age*, a pwdAccountLockedTime value is added to *entry* (only
    for one hard-coded DN, see note). Otherwise the destination entry *dn*
    is checked for a pwdAccountLockedTime attribute.

    :returns: True when the caller should delete pwdAccountLockedTime on *dn*.
    """
    pwd_change = entry['pwdChangedTime'][0]
    try:
        changed = dateutil.parser.parse(pwd_change)
    except (ValueError, OverflowError):
        print("cannot parse pwdChangedTime %s %s" % (dn, pwd_change))
        return False
    if changed.tzinfo is None:
        # Subtracting a naive from an aware datetime raises TypeError.
        # NOTE(review): assumes naive timestamps are UTC -- confirm.
        changed = changed.replace(tzinfo=now.tzinfo)
    if now - changed > max_age:
        # NOTE(review): locking is restricted to a single hard-coded test DN;
        # this looks like leftover test scaffolding -- confirm before widening.
        if dn.startswith('cn=haydar aldetest'):
            entry['pwdAccountLockedTime'] = ['000001010000Z']
            print("locking %s %s" % (dn, pwd_change))
        return False
    try:
        result = con.search_s(dn, ldap.SCOPE_BASE, "objectclass=*",
                              attrlist=['pwdAccountLockedTime'])
    except ldap.NO_SUCH_OBJECT:
        return False
    if not result:
        return False
    if any(a.lower() == 'pwdaccountlockedtime' for a in result[0][1]):
        print("unlocking %s %s" % (dn, pwd_change))
        return True
    return False


def syncLdapDestination(searchresult, destserver, destbinddn, destbindpw,
                        srcbasedn, destbasedn, destrdn, delete=True,
                        starttls=False, updateonly=False, attrfilter=None,
                        exclude=None, verbose=False, searchfilter=None,
                        pwd_max_days=None):
    """Synchronise source entries into a destination LDAP directory.

    Each source ``(dn, attrs)`` record is moved from *srcbasedn* to
    *destbasedn* (after optional attribute/objectClass renaming). It is then
    created on the destination, or modified so that it matches the source.
    Finally, when *delete* is set and *updateonly* is not, destination
    entries below *destbasedn* that match the search filter but have no
    source counterpart are deleted, children before parents.

    :param searchresult: list of ``(dn, attribute dict)`` tuples.
    :param destserver: destination host (port 389).
    :param destbinddn: bind DN for the destination.
    :param destbindpw: bind password for the destination.
    :param srcbasedn: base DN of the source records, stripped from each DN.
    :param destbasedn: base DN that replaces *srcbasedn*.
    :param destrdn: not used by this function (kept for compatibility).
    :param delete: delete destination entries missing from the source.
    :param starttls: issue StartTLS before binding.
    :param updateonly: only modify existing entries; never create or delete.
    :param attrfilter: if given, only modifications of these attribute names
        (case-insensitive) are applied; entry creation is not restricted.
    :param exclude: destination DN suffix that is never created, modified
        or deleted.
    :param verbose: print each created/deleted DN.
    :param searchfilter: filter selecting deletion candidates. It defaults to
        the module-level ``filter`` set by the script, else
        ``(objectClass=*)``.
    :param pwd_max_days: maximum password age in days for the pwdChangedTime
        hack. It defaults to the module-level ``pwd_max_days``, else 0
        (disabled).
    """
    if not searchresult:
        # Guard: an empty source would otherwise wipe the destination below.
        print("empty source, aborting")
        return

    # Backward compatibility: the script historically handed these over as
    # module globals assigned in its __main__ section.
    if searchfilter is None:
        searchfilter = globals().get('filter') or '(objectClass=*)'
    if pwd_max_days is None:
        pwd_max_days = globals().get('pwd_max_days', 0)

    # Attribute renames (source name -> destination name), case-insensitive.
    attrmap = ldap.cidict.cidict({})
    # Object class renames, keyed by lowercase source class name.
    classmap = {}
    # Operational / policy-state attributes that are never written.
    junk_attrs = ["memberof", "modifiersname", "modifytimestamp", "entryuuid",
                  "entrycsn", "contextcsn", "creatorsname", "createtimestamp",
                  "structuralobjectclass", "pwdchangedtime", "pwdfailuretime"]

    if exclude is not None:
        exclude = exclude.lower()
    attr_whitelist = None
    if attrfilter is not None:
        attr_whitelist = set(a.strip().lower() for a in attrfilter)

    update_objects = _prepare_entries(searchresult, srcbasedn, destbasedn,
                                      attrmap, classmap)
    if not update_objects:
        # Same guard as above: nothing mapped, so deleting would wipe the tree.
        print("no source entries below %s, aborting" % srcbasedn)
        return

    good = updated = deleted = failed = 0

    con = ldap.open(destserver, port=389)
    try:
        if starttls:
            con.start_tls_s()
        con.simple_bind_s(destbinddn, destbindpw)

        now = datetime.datetime.now(dateutil.tz.gettz('UTC'))
        max_age = datetime.timedelta(days=pwd_max_days)

        for dn, entry in update_objects:
            if exclude is not None and dn.lower().endswith(exclude):
                continue
            try:
                result = con.search_s(dn, ldap.SCOPE_BASE, "objectclass=*")
            except ldap.NO_SUCH_OBJECT:
                result = []

            if not result:
                # Missing on the destination: create it unless update-only.
                if updateonly:
                    continue
                try:
                    con.add_s(dn, ldap.modlist.addModlist(entry, junk_attrs))
                except ldap.LDAPError as err:
                    # One bad entry must not abort the whole sync.
                    print("%s failed %s" % (dn, err))
                    failed += 1
                    continue
                notify_created(dn)
                if verbose:
                    print("%s created" % dn)
                good += 1
                continue

            dest_dn, dest_entry = result[0]
            if exclude is not None and dest_dn.lower().endswith(exclude):
                continue

            do_unlock = False
            if pwd_max_days > 0 and 'pwdchangedtime' in entry:
                do_unlock = _sync_password_lock(con, dn, entry, now, max_age)

            mod_attrs = ldap.modlist.modifyModlist(dest_entry, entry)
            if do_unlock:
                # NOTE(review): attrfilter below may drop this unless it lists
                # pwdAccountLockedTime -- confirm that is intended.
                mod_attrs.append((ldap.MOD_DELETE, 'pwdAccountLockedTime', None))
            if attr_whitelist is not None:
                mod_attrs = [m for m in mod_attrs if m[1].lower() in attr_whitelist]
            mod_attrs = [m for m in mod_attrs if m[1].lower() not in junk_attrs]
            if not mod_attrs:
                continue  # already in sync

            try:
                con.modify_s(dn, mod_attrs)
            except ldap.LDAPError as err:
                print("error %s %s %s" % (dn, mod_attrs, err))
                failed += 1
                continue
            notify_modified(dn)
            updated += 1

        if delete and not updateonly:
            # '1.1' requests no attributes: only the DNs are needed here.
            result = con.search_s(destbasedn, ldap.SCOPE_SUBTREE, searchfilter,
                                  attrlist=['1.1'])
            keep = set(dn.lower() for dn, _ in update_objects)
            keep.add(destbasedn.lower())
            morituri = [x[0].lower() for x in result
                        if x[0] and x[0].lower() not in keep]
            # A child DN is always longer than its parent, so longest-first
            # deletes leaves before the entries that contain them.
            morituri.sort(key=len, reverse=True)
            for dn in morituri:
                if exclude is not None and dn.endswith(exclude):
                    continue
                try:
                    con.delete_s(dn)
                except ldap.LDAPError as err:
                    print("failed to delete %s %s" % (dn, err))
                    failed += 1
                    continue
                notify_deleted(dn)
                if verbose:
                    print("%s deleted" % dn)
                deleted += 1
    finally:
        con.unbind()

    print("%d entries created, %d updated, %d deleted, %d failed."
          % (good, updated, deleted, failed))


if __name__ == "__main__":
    # Usage: ldapsync.py [config-file]   (default: ./ldapsync.conf)
    conffile = "ldapsync.conf"
    if len(sys.argv) > 1:
        conffile = sys.argv[1]

    config = ConfigParser.ConfigParser()
    if not config.read(conffile):
        # ConfigParser.read() silently skips unreadable files; fail clearly.
        sys.stderr.write("cannot read config file %s\n" % conffile)
        sys.exit(1)

    def _optional(section, option, default=None):
        """Return a config value, or *default* if the section/option is absent."""
        try:
            return config.get(section, option)
        except ConfigParser.Error:
            return default

    # NOTE: syncLdapDestination() may read 'filter' and 'pwd_max_days' as
    # module globals, so these exact names are kept at module level.
    filter = None
    exclude = None

    # Source: an LDIF file ([source] file) or a live LDAP server.
    srcfile = _optional("source", "file") or None
    basedn = config.get("source", "baseDn")

    if srcfile is None:
        srv = config.get("source", "server")
        admindn = config.get("source", "bindDn")
        adminpw = config.get("source", "bindPassword")
        filter = _optional("source", "filter")
        starttls = config.getboolean("source", "starttls")

    if filter is None:
        filter = '(objectClass=*)'

    excludesubtree = _optional("destination", "excludesubtree")
    if excludesubtree is not None:
        exclude = excludesubtree.lower()

    destsrv = config.get("destination", "server")
    destadmindn = config.get("destination", "bindDn")
    destadminpw = config.get("destination", "bindPassword")
    destbasedn = config.get("destination", "baseDn")
    destdelete = config.getboolean("destination", "delete")
    rdn = config.get("destination", "rdn")

    try:
        updateonly = not config.getboolean("destination", "create")
    except (ConfigParser.Error, ValueError):
        updateonly = False
    deststarttls = config.getboolean("destination", "starttls")

    attributes = _optional("destination", "attributes")
    if attributes is None:
        attrfilter = None
    else:
        # Accept "cn, mail" as well as "cn,mail".
        attrfilter = [a.strip() for a in attributes.split(",") if a.strip()]

    try:
        pwd_max_days = int(config.get("source", "pwd_max_days"))
    except ConfigParser.Error:
        pwd_max_days = 0
    except ValueError:
        sys.stderr.write("ignoring invalid pwd_max_days setting\n")
        pwd_max_days = 0

    # All configuration has been validated; now touch the network/files.
    if srcfile:
        result = readLDIFSource(srcfile)
    else:
        result = readLdapSource(srv, admindn, adminpw, basedn, filter, starttls)

    syncLdapDestination(result, destsrv, destadmindn, destadminpw, basedn,
                        destbasedn, rdn, destdelete, deststarttls, updateonly,
                        attrfilter, exclude)
