/*
	This utility is used to unify (using hard links) two or more
	virtual servers.
	It compares each vserver with the first one (the reference server)
	and, for every common package (RPM, same version), replaces each
	non-configuration file with a hard link to the reference copy.
	It then makes the file immutable.
*/
#include <stdio.h>
#include <fcntl.h>
#include <sys/ioctl.h>
#include <sys/stat.h>
#include <unistd.h>
#include <errno.h>

#include <string>
#include <vector>
#include <list>
#include <algorithm>
#include <iostream.h>
#include <pfstream.h>
#include <linux/ext2_fs.h>
#include "vutil.h"

// Patch to help compile this utility on unpatched kernel source.
// Each flag is guarded on its own, so headers that provide only one of
// them neither break the build nor get a conflicting redefinition.
#ifndef EXT2_IMMUTABLE_FILE_FL
	#define EXT2_IMMUTABLE_FILE_FL	0x00000010
#endif
#ifndef EXT2_IMMUTABLE_LINK_FL
	#define EXT2_IMMUTABLE_LINK_FL	0x00008000
#endif


using namespace std;

// Settings selected on the command line (parsed in main()).
static bool testmode = false;	// --test: only report what would be done
static bool undo = false;		// --undo: replace hard links by copies
static bool debug = false;		// --debug: extra tracing on stdout
// Immutable flags put back on unified files; both by default, changed by
// --noflags, --immutable and --immutable-mayunlink.
static int  ext2flags = EXT2_IMMUTABLE_FILE_FL | EXT2_IMMUTABLE_LINK_FL;



/*
	Print the command-line help on stderr.
*/
static void usage()
{
	cerr <<
		"vunify version " << VERSION <<
		"\n\n"
		"vunify [ options ] reference-server vservers ... -- packages\n"
		"\n"
		"--test: Show what will be done, do not do it.\n"
		"--undo: Put back the file in place, using copies from the\n"
		"        reference server.\n"
		"--debug: Prints some debugging messages.\n"
		"--noflags: Do not put any immutable flags on the file\n"
		"--immutable: Set the immutable_file bit on the files.\n"
		"--immutable-mayunlink: Sets the immutable_link flag on files.\n"
		"\n"
		"packages is a list of package names, or the single word ALL\n"
		"to process every package installed in the reference server.\n"
		"\n"
		"By default, the immutable_file and immutable_link flags are\n"
		"set on the files. So if you want no immutable flags, you must\n"
		"use --noflags. If you want a single flag, you must use\n"
		"--noflags first, then the --immutable or --immutable-mayunlink\n"
		"flag.\n"
		;
}

/*
	A package installed in a vserver, as reported by rpm.
	name is the rpm name and version is "version-release"; rpm knows
	the package as "name-version".
*/
class PACKAGE{
public:
	string name;
	string version;	// version + release
	list<string> files;		// Files to unify
							// This is loaded on demand
	PACKAGE(const string &_name, const string &_version)
		: name (_name), version(_version)
	{
	}
	PACKAGE(const char *_name, const char *_version)
		: name (_name), version(_version)
	{
	}
	// Build a package from a "name=version-release" word.
	// Deliberately not explicit: vunify_loadallpkg inserts plain strings
	// into a list<PACKAGE> and relies on this conversion.
	PACKAGE(const string &line)
	{
		*this = line;
	}
	// Parse a "name=version-release" word, splitting on the first '='.
	// A word without '=' leaves name and version untouched.
	// The argument is copied first so p = p.name still works.
	PACKAGE & operator = (const string &_line)
	{
		string line (_line);
		string::iterator pos = find (line.begin(),line.end(),'=');
		if (pos != line.end()){
			name = string(line.begin(),pos);
			version = string(pos + 1,line.end());
		}
		return *this;
	}
	bool operator == (const PACKAGE &v) const
	{
		return name == v.name && version == v.version;
	}
	// Load the file member of the package, but exclude configuration files.
	// Runs "rpm -ql --dump" chrooted in /vservers/<ref> and keeps only the
	// regular files rpm does not flag as configuration files.
	// Nothing is done if files is already filled.
	void loadfiles(const string &ref)
	{
		if (files.empty()){
			if (debug) cout << "Loading files for package " << name << endl;
			// ipfstream runs the command through a shell on the host, and
			// name/version come from the reference server's rpm database:
			// quote every variable part so it cannot inject shell syntax.
			string cmd = "| /usr/sbin/chroot " + shell_quote("/vservers/" + ref)
				+ " /bin/rpm -ql --dump " + shell_quote(name + "-" + version);
			ipfstream oo (cmd.c_str());
			string line;
			// No fixed-size buffer: a long path no longer puts the stream
			// in a failed state and silently ends the file list.
			while (getline(oo,line)){
				if (line.empty()) continue;
				string path;
				unsigned int mode = 0;
				int isconfig = -1;
				parse_dump_line(line,path,mode,isconfig);
				if (isconfig == 0 && S_ISREG(mode)){
					files.push_front (path);
				}else if (debug){
					cout << "Package " << name << " exclude " << path << endl;
				}
			}
			if (debug) cout << "Done\n";
		}
	}
private:
	// Quote s for /bin/sh: wrap it in single quotes and turn every
	// embedded ' into '\'' (close quote, escaped quote, reopen).
	static string shell_quote(const string &s)
	{
		string out = "'";
		for (string::size_type i=0; i<s.size(); i++){
			if (s[i] == '\''){
				out += "'\\''";
			}else{
				out += s[i];
			}
		}
		out += '\'';
		return out;
	}
	// True when s is non-empty and made only of characters from digits.
	static bool only_digits(const string &s, const char *digits)
	{
		return !s.empty() && s.find_first_not_of(digits) == string::npos;
	}
	// Split one "rpm -ql --dump" line on single spaces and extract
	// field 0 (path), field 4 (mode, octal) and field 7 (config flag).
	// A field missing from the line leaves its output unchanged.
	// --dump does not escape spaces, so a path containing a space shifts
	// every field. The numeric checks on the size and mode fields reject
	// such lines instead of unifying a truncated path.
	static void parse_dump_line(const string &line, string &path,
		unsigned int &mode, int &isconfig)
	{
		string::size_type start = 0;
		for (int i=0; i<8 && start<=line.size(); i++){
			string::size_type end = line.find(' ',start);
			if (end == string::npos) end = line.size();
			const string field = line.substr(start,end-start);
			if (i == 0){
				path = field;
			}else if (i == 1){
				if (!only_digits(field,"0123456789")) return;
			}else if (i == 4){
				if (!only_digits(field,"01234567")) return;
				sscanf(field.c_str(),"%o",&mode);
			}else if (i == 7){
				isconfig = atoi(field.c_str());
			}
			start = end + 1;
		}
	}
};


/*
	Stream a package as "name-version", the form rpm uses to name it.
*/
static ostream & operator << (ostream &c, const PACKAGE &p)
{
	c << p.name;
	c << '-';
	c << p.version;
	return c;
}

/*
	Debugging helper: print a value on stdout, prefixed with "xx ".
*/
template<class T>
	void printit(T a){
		cout << "xx ";
		cout << a;
		cout << endl;
	}

/*
	Function object for for_each: prints "<title> <value>" on stdout
	for every element it is applied to.
*/
template<class T>
	class printer{
		string title;
		public:
		explicit printer (const char *_title): title(_title){}
		// Returns true. The return value was declared but never
		// produced before, which is undefined behaviour.
		bool operator()(const T &a) const {
			cout << title << " " << a << endl;
			return true;
		}
	};


/*
	Load the list of all packages in a vserver
*/
static void vunify_loadallpkg (string &refserver, list<PACKAGE> &packages)
{
	string cmd = "| /usr/sbin/chroot /vservers/" + refserver + " /bin/rpm -qa"
		+ " --queryformat \"%{name}=%{version}-%{release}\\n\"";
	// cout << "command " << cmd << endl;
	ipfstream oo (cmd.c_str());
	copy (istream_iterator<string>(oo),istream_iterator<string>()
		,inserter(packages,packages.begin()));
}

/*
	Set the immutable flag on a file
*/
static int setext2flag (const char *fname, bool set)
{
	int ret = -1;
	int fd = open (fname,O_RDONLY);
	if (fd == -1){
		cerr << "Can't open file " << fname 
			<< " (" << strerror(errno) << ")\n";
	}else{
		int flags = set ? ext2flags : 0;
		ret = ioctl (fd,EXT2_IOC_SETFLAGS,&flags);
		close (fd);
		if (ret == -1){
			cerr << "Can't " << (set ? "set" : "unset")
				<< " immutable flag on file "
				<< fname
				<< " (" << strerror(errno) << ")\n";
		}
	}
	return ret;
}

/*
	Object to unify a file.
	For each path (relative to the vserver root), the target copy is
	removed and replaced by a hard link to the reference copy. With
	--undo, it is replaced by a private copy instead. The reference
	file's immutable flags are lifted for the operation and are always
	put back afterwards.
	ret is set to -1 whenever a step fails.
*/
class file_unifier{
	string &ref_server,&target_server;
	int &ret;
	public:
	file_unifier(string &_ref, string &_target, int &_ret)
		: ref_server(_ref),target_server(_target), ret(_ret)
	{}
	void operator()(const string &file)
	{
		string refpath = "/vservers/" + ref_server + file;
		string dstpath = "/vservers/" + target_server + file;
		if (debug) cout << "Unify " << refpath << " -> " << dstpath << endl;
		struct stat st;
		// lstat, not stat: a symlink in the reference server must not be
		// followed, since an absolute target resolves on the host.
		if (lstat(refpath.c_str(),&st)==-1){
			if (debug) cout << "File " << refpath << " does not exist, ignored\n";
			return;
		}
		if (!S_ISREG(st.st_mode)){
			cerr << "File " << refpath << " is not a regular file, ignored\n";
			return;
		}
		// Lift the flags first: the target may already be a hard link to
		// this inode (earlier run), which the flags would keep in place.
		if (setext2flag(refpath.c_str(),false)==-1){
			ret = -1;
			return;
		}
		if (unlink(dstpath.c_str())==-1){
			ret = -1;
			cerr << "Can't delete file " << dstpath
				<< " (" << strerror(errno) << ")\n";
		}else if (undo){
			copy_from_ref(refpath,dstpath,st);
		}else if (link(refpath.c_str(),dstpath.c_str())==-1){
			ret = -1;
			cerr << "Can't link file " << refpath << " to " << dstpath
				<< " (" << strerror(errno) << ")\n";
			// The target file is already gone: put a private copy back
			// rather than leaving the package incomplete in that vserver.
			copy_from_ref(refpath,dstpath,st);
		}
		// Always put the flags back, even when this target failed:
		// other vservers may be unified on this file.
		if (setext2flag(refpath.c_str(),true)==-1){
			ret = -1;
		}
	}
	private:
	// Copy the reference file to dstpath; report and record any failure.
	void copy_from_ref(const string &refpath, const string &dstpath,
		struct stat &st)
	{
		if (file_copy(refpath.c_str(),dstpath.c_str(),st)==-1){
			ret = -1;
			cerr << "Can't copy file " << refpath << " to " << dstpath
				<< " (" << strerror(errno) << ")\n";
		}
	}
};
// Predicate: true when a package carries the same name as the one given
// at construction, whatever its version.
class same_name{
	PACKAGE &pkg;
public:
	same_name(PACKAGE &_pkg) : pkg(_pkg) {}
	bool operator()(const PACKAGE &p)
	{
		const string &wanted = pkg.name;
		return p.name == wanted;
	}
};
// Predicate to decide if a package must be unified
class package_unifier{
public:
	string &ref_server,&target_server;
	list<PACKAGE> &target_packages;
	int &ret;
	package_unifier(string &_ref,
			string &_target,
			list<PACKAGE> &_target_packages,
			int &_ret)
		: ref_server(_ref),target_server(_target)
		, target_packages(_target_packages) , ret(_ret)
	{}
	void operator()(PACKAGE &pkg)
	{
		if (find(target_packages.begin(),target_packages.end(),pkg)
			!=target_packages.end()){
			// Ok, the package is also in the target vserver
			cout << "Unify pkg " << pkg << " from " << ref_server << " to "
				<< target_server << endl;

			if (!testmode){
				pkg.loadfiles(ref_server);
				for_each (pkg.files.begin(),pkg.files.end()
					,file_unifier(ref_server,target_server,ret));
			}
		}else if (testmode){
			// The package is missing, in test mode we provide more information
			if (find_if(target_packages.begin(),target_packages.end(),same_name(pkg))
				!=target_packages.end()){
				cout << pkg << " exist in server " << target_server << " not unified\n";
			}else{
				cout << pkg << " does not exist in server " << target_server << endl;
			}
		}
	}
};

// For each vserver, find the common packages and unify them.
// Applied by for_each to every target vserver name.
class server_unifier{
public:
	list<PACKAGE> &ref_packages;
	// Owned copy: the constructor receives the name by value, so a
	// reference member (as before) dangled once the constructor returned.
	string ref_server;
	int &ret;
	server_unifier(string _ref_server, list<PACKAGE> &_packages, int &_ret)
		: ref_packages(_packages),ref_server(_ref_server), ret(_ret)
		{}
	void operator()(string serv)
	{
		list<PACKAGE> pkgs;
		vunify_loadallpkg (serv,pkgs);
		for_each(ref_packages.begin(),ref_packages.end()
			,package_unifier(ref_server,serv,pkgs,ret));
	}
};
// Predicate for list::remove_if. It returns true (drop the package)
// when the package name is not one of the command-line words in
// [argv0, argvn).
class deleteif{
public:
	char **argv0,**argvn;
	deleteif(char **_argv0, char **_argvn): argv0(_argv0),argvn(_argvn){}
	bool operator()(const PACKAGE &pkg)
	{
		char **word = argv0;
		while (word < argvn){
			if (pkg.name == *word) return false;
			++word;
		}
		return true;
	}
};

/*
	vunify [ options ] reference-server vservers ... -- packages
	Parses the options and loads the reference server's package list.
	The list is restricted to the packages named after "--", or kept
	whole with the single word ALL. Every listed vserver is then
	unified against the reference.
	Returns 0 on success, -1 on a usage error or when any file failed.
*/
int main (int argc, char *argv[])
{
	int ret = -1;
	int i;
	for (i=1; i<argc; i++){
		const char *arg = argv[i];
		if (strcmp(arg,"--test")==0){
			testmode = true;
		}else if (strcmp(arg,"--undo")==0){
			undo = true;
		}else if (strcmp(arg,"--debug")==0){
			debug = true;
		}else if (strcmp(arg,"--noflags")==0){
			ext2flags = 0;
		}else if (strcmp(arg,"--immutable")==0){
			ext2flags |= EXT2_IMMUTABLE_FILE_FL;
		}else if (strcmp(arg,"--immutable-mayunlink")==0){
			ext2flags |= EXT2_IMMUTABLE_LINK_FL;
		}else{
			break;
		}
	}
	if (i==argc){
		usage();
	}else{
		string refserv = argv[i++];
		list<string> vservers;
		for (; i<argc && strcmp(argv[i],"--")!=0; i++){
			vservers.push_front (argv[i]);
		}
		for_each (vservers.begin(),vservers.end(),printer<string>("vservers"));
		if (i == argc || strcmp(argv[i],"--")!=0){
			usage();
		}else if (find(vservers.begin(),vservers.end(),refserv)
			!= vservers.end()){
			// Unifying the reference with itself would unlink each
			// reference file and then fail to link it back: data loss.
			cerr << "vunify: the reference server " << refserv
				<< " can't also be a vserver to unify\n";
		}else{
			i++;
			if (i < argc){
				list<PACKAGE> packages;
				vunify_loadallpkg (refserv,packages);
				if (i != argc-1 || strcmp(argv[i],"ALL")!=0){
					// We keep only the packages supplied on the command line
					packages.remove_if(deleteif (argv+i,argv+argc));
				}
				ret = 0;
				umask (0);
				for_each (vservers.begin(),vservers.end(),server_unifier(refserv,packages,ret));
			}else{
				usage();
			}
		}
	}
	return ret;
}




