// sudokuo - program to solve Sudoku puzzles; prints the puzzle, the number of
// solutions and the first solution found
// takes one argument, the name of an input file; with no argument, reads stdin
// input consists of nine lines, each of nine characters, where filled positions
// are represented by digits 1 to 9 and unfilled positions are represented by
// any other character

#include <cstdlib>
#include <fstream>
#include <iostream>
#include <string>
#include "search.hpp"

// square[i][j] is the index (0-8, numbered row-major) of the 3x3 square that
// contains cell (i, j).  rowa, rowb and rowc are the rows of that table for
// grid rows 0-2, 3-5 and 6-8 respectively.
const int rowa[] = { 0,0,0, 1,1,1, 2,2,2 };
const int rowb[] = { 3,3,3, 4,4,4, 5,5,5 };
const int rowc[] = { 6,6,6, 7,7,7, 8,8,8 };

// The table of row pointers is itself const, so no entry can be re-pointed.
const int *const square[] = { rowa,rowa,rowa, rowb,rowb,rowb, rowc,rowc,rowc };

class Sudoku : public search::Client {

private:

search::Server<search::PosFixed<243> > searchserver;

// Cursor state shared by the search::Client callbacks further down.
int ishape;		// row of the cell getshapeparts describes next
int jshape;		// column of the cell getshapeparts describes next
int lshape;		// always ishape * 9: that row's first index in loctab
int iloc;		// next loctab index returned by gethedposs
int kloc, kloccnt;	// digit index (iloc % 9) of the current / next head position
int icell;		// next celltab entry returned by getsit

// Site-id bases for the current cell's row, column and square; getsit adds
// the digit index to each of them.
int celltab[3];

int grid[9][9];		// cell values: 1-9, or 0 for an unfilled cell
bool pos[9][9][9];	// pos[i][j][k]: digit k+1 is a candidate for cell (i, j)
int loctab[243];	// 1 at each head position allowed for the current cell

int nsol;		// number of solutions counted (see setsolparts)
int imap[81], jmap[81];	// row and column of each row-major index 0..80

// Derive pos from grid: digit k+1 is a candidate for cell (i, j) unless it
// is already given in row i, column j or the cell's square.  A given cell
// always keeps its own digit as a candidate.  Also fills imap/jmap with the
// row and column of each row-major index 0..80.
void load(void)
{
	// full[k] records which rows, columns and squares already hold digit k+1
	struct { bool row[9], col[9], sqr[9]; } full[9], *f;
	int i, j, k, s, t;

	for (f = full; f < full + 9; f++) for (i = 0; i < 9; i++)
		f->row[i] = f->col[i] = f->sqr[i] = false;
	for (i = 0; i < 9; i++) for (j = 0; j < 9; j++) if (grid[i][j]) {
		f = full + (k = grid[i][j] - 1);
		s = square[i][j];
		f->row[i] = f->col[j] = f->sqr[s] = true;
	}
	for (i = 0; i < 9; i++) for (j = 0; j < 9; j++) {
		s = square[i][j];
		for (k = 0, f = full; k < 9; k++, f++)
			pos[i][j][k] = !(f->row[i] || f->col[j] || f->sqr[s]);
		// the given's own digit was excluded above by its own row/col/square
		if (grid[i][j]) k = grid[i][j] - 1, pos[i][j][k] = true;
	}
	t = 0;
	for (i = 0; i < 9; i++) for (j = 0; j < 9; j++)
		imap[t] = i, jmap[t] = j, t++;
}

// Read one character from in.  Returns it as a value in 0..255, or -1 when
// no character could be read (end of input or stream error).  The value is
// taken through unsigned char so that no real byte can compare equal to -1.
int readchar(std::istream &in)
{
	char c;

	if (!in.get(c)) return (-1);
	return (static_cast<unsigned char>(c));
}

// Read nine rows of nine characters from in into grid; '1'-'9' become that
// digit and any other character an unfilled (0) cell.  Every row must be
// followed by a newline, optionally preceded by '\r'; the ninth row may
// instead be followed by end of input.  Returns "" on success, otherwise a
// short error message (grid is then only partially filled).
const std::string get(std::istream &in)
{
	int i, j, c;

	for (i = 0; i < 9; i++) {
		for (j = 0; j < 9; j++) {
			c = readchar(in);
			if (c == -1 || c == '\n') return ("Short line");
			grid[i][j] = c >= '1' && c <= '9'? c - '0': 0;
		}
		c = readchar(in);
		if (c == '\r') c = readchar(in);	// tolerate CRLF line endings
		if (c == -1) {
			if (i == 8) break;	// missing final newline is fine
			return ("Unexpected EOF");
		}
		if (c != '\n') return ("Long line");
	}
	return ("");
}

// Print grid to standard output as nine rows of space-separated digits,
// with '.' for unfilled cells.
void put(void)
{
	int i, j;

	for (i = 0; i < 9; i++) {
		for (j = 0; j < 9; j++)
			std::cout <<
				(grid[i][j] != 0?
					static_cast<char>(grid[i][j] + '0'): '.') << ' ';
		std::cout << '\n';
	}
}

public:

// Read a puzzle from in, print it, solve it, then print the number of
// solutions and the grid again (filled with the first solution, if any).
// On a malformed puzzle, print the error to stderr and stop, because the
// unread part of grid holds no meaningful values.
void run(std::istream &in)
{
	const std::string err = get(in);

	if (!err.empty()) {
		std::cerr << err << std::endl;
		return;
	}
	put();
	std::cout << std::endl;
	load();
	nsol = 0;	// defined even if the server never calls setsol
	searchserver.run(this);
	std::cout << "Number of solutions: " << nsol << std::endl;
	std::cout << std::endl;
	put();
}

public: // search::Client interface

// Number of sites: 81 (row, digit) + 81 (column, digit) + 81 (square, digit)
// slots, laid out as built by getshapeparts/getsit (ids 0..242).
virtual LocId getnsit(void)
{
	return (243);
}

// Restart the cell cursor at (0, 0); there is one shape per cell, 81 in all.
virtual int getwholeshapes(void)
{
	ishape = 0;
	jshape = 0;
	lshape = 0;
	return (81);
}

// Describe the cell at the cursor, then advance the cursor row-major.
// loctab is cleared and loctab[ishape * 9 + k] set for each candidate
// digit k+1; celltab receives the site-id bases of the cell's row, column
// and square.  Returns 1, presumably the number of parts of this shape.
virtual int getshapeparts(void)
{
	int l, k, sshape;

	for (l = 0; l < 243; l++) loctab[l] = 0;
	for (k = 0; k < 9; k++) if (pos[ishape][jshape][k])
			loctab[lshape + k] = 1;
	sshape = square[ishape][jshape];
	celltab[0] = ((0*9) + ishape) * 9;	// row sites 0..80
	celltab[1] = ((1*9) + jshape) * 9;	// column sites 81..161
	celltab[2] = ((2*9) + sshape) * 9;	// square sites 162..242
	iloc = 0;
	kloccnt = 0;
	jshape++;
	if (jshape == 9) jshape = 0, ishape++, lshape += 9;
	return (1);
}

// Return loctab[iloc] (whether that head position is allowed for the
// current cell) and advance iloc; kloc becomes iloc % 9, the digit index
// used by getsit.
virtual int gethedposs(void)
{
	kloc = kloccnt++;
	if (kloccnt == 9) kloccnt = 0;
	return (loctab[iloc++]);
}

// Restart the celltab cursor; each placement covers 3 sites.
virtual int getpossits(void)
{
	icell = 0;
	return (3);
}

// Next site covered by the current placement: a row, column or square base
// from celltab plus the digit index kloc.
virtual LocId getsit(void)
{
	return (celltab[icell++] + kloc);
}

// Reset the solution count.
virtual void setsol(void)
{
	nsol = 0;
}

// Count one more solution (presumably called once per solution reported
// by the server); nparts is not used.
virtual void setsolparts(int nparts)
{
	nsol++;
}

// Record one placement, for the first solution only.  shape is the
// row-major cell index; ihed is presumably the head position
// (row * 9 + digit index), so jmap[ihed] is the digit index.  isit is
// not used.
virtual void setsolpart(ShapeId shape, SitId ihed, SitId isit)
{
	if (nsol != 1) return;
	grid[imap[shape]][jmap[shape]] = jmap[ihed] + 1;
}

}; // end of class Sudoku

// Solve the puzzle read from the file named by argv[1], or from standard
// input when no argument is given.  Exits with status 1 (message on stderr)
// if the named file cannot be opened.
int main(int argc, char **argv)
{
	Sudoku sudoku;
	std::ifstream infile;

	if (argc > 1) {
		infile.open(argv[1]);
		if (!infile.is_open()) {
			std::cerr << "Can't open " << argv[1] << std::endl;
			return (1);	// return, not exit(), so destructors run
		}
	}
	sudoku.run(argc > 1? infile: std::cin);
	return (0);
}
