#!/usr/bin/env python
#
# Copyright (C) 2013 EPITA Research and Development Laboratory (LRDE)
#
# This script is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, version 2 of the License.
#
# It is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
# or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public
# License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Olena.  If not, see <http://www.gnu.org/licenses/>.

import sys
sys.path.append("./lib/html")
import HTML
import os
import subprocess
import shlex
import shutil
import argparse
import logging
import glob
import array

class Config:
    """Benchmark settings: directory layout, supported options, tool flags.

    The instance built in ``main`` is also handed to ``argparse`` as its
    namespace, so parsed command-line options become attributes of the
    same object.
    """

    # Class-level constants shared by every instance.
    supported_impl = ["sauvola", "sauvola_msk", "sauvola_mskx", "wolf",
                      "otsu", "niblack", "kim"]
    supported_text_size = ["small", "medium", "large"]
    supported_quality = ["clean", "scanned", "orig"]
    tesseract_options = "-l fra -psm 7"

    def __init__(self):
        self.bench_version = "1.0"

        # The benchmark root is the current working directory, stored
        # with a trailing separator.
        root = os.getcwd() + os.sep
        self.bench_dir = root

        # Every sub-directory is appended after one more separator, so
        # the resulting paths contain a doubled separator after the root.
        prefix = root + os.sep
        self.gt_dir = prefix + "gt"
        self.gt_bin_dir = self.gt_dir + os.sep + "bin"
        self.gt_ocr_dir = self.gt_dir + os.sep + "ocr"
        self.input_dir = prefix + "input"
        self.output_dir = prefix + "output"
        self.bin_dir = prefix + "bin"
        self.bin_conf_dir = prefix + "bin.conf"
        self.result_dir = prefix + "result"


#------------------------------------
def setup_logger():
    """Configure and return the "logfun" logger.

    Two handlers are attached: a file handler that rewrites
    ``./bench.log`` on each call and keeps only the bare message, and a
    stream handler that echoes records on standard output.
    """
    logger = logging.getLogger("logfun")
    logger.setLevel(logging.DEBUG)

    # Full trace goes to the log file (truncated on open).
    file_handler = logging.FileHandler("./bench.log", 'w')
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(logging.Formatter("%(message)s"))
    logger.addHandler(file_handler)

    # Mirror the records on stdout.
    logger.addHandler(logging.StreamHandler(sys.stdout))

    return logger

# Module-level logger, created once when the file is loaded and used by
# the functions below to report progress and errors.
log = setup_logger()


#------------------------------------
def compute_binarization(conf, quality):
    """Binarize the input documents of one quality with each implementation.

    For every implementation of ``conf.use_implementations`` the command
    is read from the first line of ``<conf.bin_conf_dir>/<impl>.conf`` and
    run on each ``*.<quality>.png`` file of ``conf.input_dir`` whose
    output is missing (or on all of them when ``conf.force_regen_output``
    is set).  An empty command means the binary is not available: the
    missing result is reported and appended to
    ``bin_computation_warnings.log`` in ``conf.bench_dir``.

    Args:
        conf: configuration object (directories and parsed options).
        quality: document quality tag used in the file names, e.g.
            "clean", "orig" or "scanned".

    The process exits with status 1 if a binarization command cannot be
    started at all.
    """
    log.debug("* Computing binarization for " + quality + " documents.")

    # Make sure every implementation has its output directory.
    for impl in conf.use_implementations:
        impl_out_dir = conf.output_dir + os.sep + impl
        if not os.path.exists(impl_out_dir):
            os.makedirs(impl_out_dir)

    suffix = "." + quality + ".png"
    warnings_log = conf.bench_dir + os.sep + "bin_computation_warnings.log"
    all_files_ok = True

    for impl in conf.use_implementations:
        log.debug("* Binarizing with " + impl)

        # Read the implementation command.  Only this step is guarded by
        # the I/O handler so that other failures are not misreported as a
        # configuration problem.
        conf_path = conf.bin_conf_dir + os.sep + impl + ".conf"
        try:
            with open(conf_path, 'r') as bconf:
                # Strip the end of line: a blank first line means
                # "no command available".
                cmd = bconf.readline().strip()
        except (IOError, OSError):
            log.exception("Cannot open configuration file " + conf_path + " for " + impl)
            continue

        for fname in os.listdir(conf.input_dir):
            if not fname.endswith(suffix):
                continue

            out = conf.output_dir + os.sep + impl + os.sep + os.path.basename(fname)
            if os.path.exists(out) and not conf.force_regen_output:
                continue

            if not cmd:
                # Available results may have been computed on a different
                # system and the binary may not be available for
                # re-computing results. We must warn the user in case a
                # binarization is missing and cannot be re-computed.
                msg = "Cannot compute binarization for " + fname + " with impl " + impl + " or result not available!\n"
                log.warning(msg)
                all_files_ok = False
                with open(warnings_log, 'a') as fd:
                    fd.write(msg)
                continue

            log.debug("  * Binarizing " + fname)
            command = conf.bin_dir + os.sep + cmd + " " + conf.input_dir + os.sep + fname + " " + out
            try:
                status = subprocess.call(shlex.split(command))
            except OSError:
                # The executable could not be launched at all.
                log.exception("Cannot compute binarization for " + fname)
                sys.exit(1)

            # subprocess.call() reports failures through the exit status,
            # not through an exception.
            if status != 0:
                log.warning("Binarization of " + fname + " with impl " + impl
                            + " exited with status " + str(status))
                all_files_ok = False

    if not all_files_ok:
        log.warning("WARNING: All the binarizations could not be computed. "
                    "Further processing may produce scores which are not comparable.")
    else:
        log.debug("Binarization for " + quality + " quality done.")


#------------------------------------
def compute_lines(conf):
    """Extract text-line images from the binarization outputs.

    For each implementation of ``conf.use_implementations`` the
    ``line_maker`` tool of ``conf.bin_dir`` is run on the implementation
    output directory, unless a ``lines.done`` marker already exists in
    its ``lines`` sub-directory and ``conf.force_regen_output`` is not
    set.  Binarization outputs of the full documents are expected to be
    available already.

    The marker is written only after ``line_maker`` returned status 0,
    so a failed extraction is retried on the next run.  The process exits
    with status 1 if ``line_maker`` cannot be started.
    """
    log.debug("* Computing text lines from binarization outputs. (it might take some time...)")

    for impl in conf.use_implementations:
        log.debug("Computing lines from binarization results of " + impl)
        impl_dir = conf.output_dir + os.sep + impl
        line_base_dir = impl_dir + os.sep + "lines"
        done_marker = line_base_dir + os.sep + "lines.done"

        if os.path.exists(done_marker) and not conf.force_regen_output:
            continue

        # One sub-directory per supported text size.
        for text_size in conf.supported_text_size:
            size_dir = line_base_dir + os.sep + text_size
            if not os.path.exists(size_dir):
                os.makedirs(size_dir)

        #FIXME: change output names from *PDF.orig.png30.png to *PDF30.orig.png to help matching gt which have names of form *PDF30.txt
        command = (conf.bin_dir + "/line_maker --extract " + impl_dir + " "
                   + conf.input_dir + os.sep + "bboxes " + line_base_dir)
        try:
            # The tool output is discarded.
            with open(os.devnull, "w") as devnull:
                status = subprocess.call(shlex.split(command), stdout=devnull, stderr=devnull)
        except OSError:
            # The executable could not be launched at all.
            log.exception("Cannot compute lines for " + impl)
            sys.exit(1)

        if status != 0:
            # Do not mark the extraction as done: it will be retried.
            log.error("line_maker exited with status " + str(status) + " for " + impl)
            continue

        # Create (or truncate) the marker and release the handle at once.
        with open(done_marker, "w"):
            pass

#------------------------------------
def pixel_based_comparison(conf, impl, file):
    """Compare one binarization output with its ground truth using ``eval_gt``.

    Args:
        conf: configuration object; ``bin_dir``, ``gt_bin_dir`` and
            ``output_dir`` are read.
        impl: implementation name (sub-directory of ``conf.output_dir``).
        file: file name, identical in the output and ground-truth
            directories.

    Returns:
        A ``(precision, recall, fm)`` tuple of strings: fields 7, 9 and
        11 of the tool output split on single spaces.  The fields are
        returned untouched, so the last one may still carry the end of
        line printed by the tool.

    Raises:
        subprocess.CalledProcessError: the tool did not print the
            expected fields (for instance because the binarization
            output is missing).  The raw output is attached to the
            exception.  Nothing is logged here; reporting is left to the
            caller.
        OSError: the ``eval_gt`` executable could not be started.
    """
    abs_path_gt = conf.gt_bin_dir + os.sep + file
    abs_path_output = conf.output_dir + os.sep + impl + os.sep + file
    command = conf.bin_dir + os.sep + "eval_gt " + abs_path_output + " " + abs_path_gt

    # Run the comparison and get its text output.
    p = subprocess.Popen(shlex.split(command), stdout=subprocess.PIPE)
    out, _ = p.communicate()
    if not isinstance(out, str):
        # Python 3 pipes deliver bytes.
        out = out.decode("utf-8", "replace")
    res = out.split(' ')

    # Popen/communicate never raise CalledProcessError by themselves:
    # turn an unusable output into the error callers are prepared for,
    # instead of failing later with an IndexError.
    if len(res) < 12:
        raise subprocess.CalledProcessError(p.returncode, command, out)

    # Retrieve Precision, Recall and F-Measure in output.
    return res[7], res[9], res[11]


#------------------------------------
def generate_html_output(conf, scores, html, headers):
    """Render a ';'-separated score file as an HTML table.

    Args:
        conf: configuration object (not used here).
        scores: path of the text file to read; each line becomes a row.
        html: path of the HTML file to write.
        headers: list of column titles for the table header row.
    """
    table = HTML.Table(header_row=headers)

    # One row per line, one cell per ';'-separated field.  Lines are not
    # stripped, so the last cell keeps its end-of-line character.
    with open(scores) as src:
        table.rows.extend(row.split(';') for row in src.readlines())

    rendered = str(table)

    # Write data
    with open(html, 'w') as dst:
        dst.write(rendered)


#------------------------------------
def bin_evaluation(conf):
    """Run the pixel-based evaluation and write its CSV/HTML reports.

    For each implementation of ``conf.use_implementations`` every
    ``*.clean.png`` ground-truth file of ``conf.gt_bin_dir`` is compared
    with the corresponding output.  Results are written in
    ``conf.result_dir``:

    * ``bin_evaluation_<impl>.csv`` / ``.html``: one
      ``page;precision;recall;fm`` row per evaluated page;
    * ``bin_evaluation_summary.csv`` / ``.html``: one row per
      implementation holding the mean of each measure.

    Pages that cannot be evaluated are skipped and recorded in
    ``bin_evaluation_error.log`` in ``conf.bench_dir``.  An
    implementation without any evaluated page gets no summary row.
    """
    log.debug("* Performing pixel-based evaluation...")

    # Create results directory if needed
    if not os.path.exists(conf.result_dir):
        os.makedirs(conf.result_dir)

    summary_csv = conf.result_dir + os.sep + "bin_evaluation_summary.csv"
    error_log = conf.bench_dir + os.sep + "bin_evaluation_error.log"

    with open(summary_csv, 'w') as summary:
        for impl in conf.use_implementations:
            log.debug("  Processing results with " + impl)
            impl_base = conf.result_dir + os.sep + "bin_evaluation_" + impl

            validfiles = 0
            stats = [0.0, 0.0, 0.0]  # precision, recall, fm
            with open(impl_base + ".csv", 'w') as per_page:
                for gt in os.listdir(conf.gt_bin_dir):
                    if not gt.endswith(".clean.png"):
                        continue

                    log.debug("    - " + gt)
                    page_id = gt.split("_")[0]
                    try:
                        precision, recall, fm = pixel_based_comparison(conf, impl, gt)
                        # The fields come straight from the tool output
                        # and may carry whitespace or the end of line.
                        precision, recall, fm = precision.strip(), recall.strip(), fm.strip()
                        # Convert before writing so that a malformed
                        # value never leaves a partial row behind.
                        values = [float(precision), float(recall), float(fm)]
                    except (subprocess.CalledProcessError, IndexError, ValueError):
                        msg = "No binarization found for file " + gt + " and " + impl
                        log.warning(msg)
                        with open(error_log, 'a') as fd:
                            fd.write(msg + "\n")
                        # Ignoring file.
                        continue

                    per_page.write(page_id + ";" + precision + ";" + recall + ";" + fm + "\n")
                    for idx, val in enumerate(values):
                        stats[idx] += val
                    validfiles += 1

            # Taking mean of each stat data in order to compute the total
            # score for the method.
            if validfiles:
                means = [str(total / float(validfiles)) for total in stats]
                summary.write(impl + ";" + ";".join(means) + "\n")
            else:
                log.warning("No valid binarization result for " + impl
                            + ": no summary score computed.")

            # HTML output (rows are keyed by page identifier).
            generate_html_output(conf,
                                 impl_base + ".csv",
                                 impl_base + ".html",
                                 ["Page", "Precision", "Recall", "FM"])

    # HTML output
    generate_html_output(conf,
                         summary_csv,
                         conf.result_dir + os.sep + "bin_evaluation_summary.html",
                         ["Method", "Precision", "Recall", "FM"])


#------------------------------------
def ocr_evaluation(conf):
    """Run the OCR-based evaluation and write its CSV/HTML reports.

    Three steps are performed for every combination of
    ``conf.use_text_size`` and ``conf.use_quality``:

    1. tesseract is run on each text-line image of each implementation;
    2. the edit distance between each OCR output and its ground truth is
       written to ``ocr_evaluation_per_file_<size>_<quality>.csv``;
    3. the distances are accumulated per implementation and the error
       rate (in percent of ground-truth characters) is written to
       ``ocr_evaluation_<size>_<quality>.csv``.

    A global HTML summary is finally generated in ``conf.result_dir``.
    The process exits with status 1 if tesseract or ``edit_dist`` cannot
    be started.
    """
    log.debug("* Performing OCR-based evaluation...")

    _ocr_run_tesseract(conf)
    _ocr_score_per_file(conf)
    _ocr_summarize_scores(conf)

    # HTML OUTPUT (Summary)
    generate_full_ocr_result_html_output(conf, conf.result_dir + os.sep + "ocr_evaluation_summary.html")

    log.debug("OCR results computed.")


def _ocr_run_tesseract(conf):
    """Run tesseract on the text-line images of every implementation."""
    for text_size in conf.use_text_size:
        for quality in conf.use_quality:
            log.debug("Computing OCR for quality " + quality + " and " + text_size + " text with")

            for impl in conf.use_implementations:
                log.debug("    * " + impl)
                impl_dir = conf.output_dir + os.sep + impl
                fdir = impl_dir + os.sep + "lines" + os.sep + text_size
                outdir = impl_dir + os.sep + "ocr" + os.sep + text_size

                # Prepare output directories: drop the previous results
                # of this quality (looked up in the output directory, not
                # in the current one).
                if os.path.exists(outdir):
                    for stale in glob.glob(outdir + os.sep + "*." + quality + ".png.txt"):
                        os.remove(stale)
                else:
                    os.makedirs(outdir)

                images = [name for name in os.listdir(fdir) if quality + ".png" in name]
                count = len(images)

                with open(os.devnull, 'w') as devnull:
                    for processed, image in enumerate(images, 1):
                        # Progress indicator rewritten in place.
                        sys.stdout.write("\rProcessing " + str(processed) + "/" + str(count))
                        sys.stdout.flush()
                        command = ("tesseract " + conf.tesseract_options + " "
                                   + fdir + os.sep + image + " " + outdir + os.sep + image)
                        try:
                            subprocess.call(shlex.split(command), stdout=devnull, stderr=devnull)
                        except OSError:
                            # The executable could not be launched at all.
                            log.exception("Cannot compute OCR output for " + impl + " with file " + image)
                            sys.exit(1)

                # Clear the progress line.
                sys.stdout.write("\r" + " " * 80 + "\r")
                sys.stdout.flush()


def _ocr_edit_distance(conf, impl, text_size, quality, file_gt):
    """Return the ``edit_dist`` output for one GT file, or "" if no OCR output exists."""
    out_file = (conf.output_dir + os.sep + impl + os.sep + "ocr" + os.sep + text_size
                + os.sep + os.path.splitext(file_gt)[0] + "." + quality + ".png.txt")
    if not os.path.exists(out_file):
        log.warning(out_file + " is not available\n")
        return ""

    gt_path = conf.gt_ocr_dir + os.sep + text_size + os.sep + file_gt
    try:
        p = subprocess.Popen(shlex.split(conf.bin_dir + os.sep + "edit_dist " + out_file + " " + gt_path),
                             stdout=subprocess.PIPE)
        out, _ = p.communicate()
    except OSError:
        # The executable could not be launched at all.
        log.exception("Cannot compare files " + out_file + " and " + file_gt)
        sys.exit(1)

    if not isinstance(out, str):
        # Python 3 pipes deliver bytes.
        out = out.decode("utf-8", "replace")
    return out.rstrip('\n')


def _ocr_score_per_file(conf):
    """Write one CSV per (text size, quality) with the edit distance of each GT file."""
    # The directory may not exist yet when the pixel-based evaluation is
    # disabled.
    if not os.path.exists(conf.result_dir):
        os.makedirs(conf.result_dir)

    for text_size in conf.use_text_size:
        fdir = conf.gt_ocr_dir + os.sep + text_size
        gt_files = sorted(name for name in os.listdir(fdir) if name.endswith(".txt"))

        for quality in conf.use_quality:
            csv_path = (conf.result_dir + os.sep + "ocr_evaluation_per_file_"
                        + text_size + "_" + quality + ".csv")
            with open(csv_path, 'w') as of:
                # Write CSV headers
                of.write(";".join(["file"] + list(conf.use_implementations)) + "\n")

                # Compute edit distance for each text line (if available).
                # A missing OCR output yields an empty field rather than
                # a value left over from a previous file.
                for file_gt in gt_files:
                    fields = [file_gt.split("_")[0] + "." + file_gt.split(".")[1]]
                    for impl in conf.use_implementations:
                        fields.append(_ocr_edit_distance(conf, impl, text_size, quality, file_gt))
                    of.write(";".join(fields) + "\n")


def _ocr_count_chars(path):
    """Return the number of characters of a text file, line ends included."""
    with open(path, 'r') as gt:
        return sum(len(text) for text in gt)


def _ocr_summarize_scores(conf):
    """Accumulate the per-file distances into one error rate per implementation."""
    # Middle part of the ground-truth file names, which is dropped from
    # the keys stored in the per-file CSV.
    gt_infix = "_nouvel-obs_hbhnr300_constructedPdf_Nouvelobs2402PDF."

    for text_size in conf.use_text_size:
        for quality in conf.use_quality:
            per_file_csv = (conf.result_dir + os.sep + "ocr_evaluation_per_file_"
                            + text_size + "_" + quality + ".csv")
            summary_base = conf.result_dir + os.sep + "ocr_evaluation_" + text_size + "_" + quality

            # Accumulators, one slot per implementation (plain integers:
            # no fixed-width overflow).
            headers = []
            accu = []
            total_char_count = []

            with open(per_file_csv, 'r') as per_file:
                for row in per_file:
                    value = row.rstrip('\n').split(';')

                    # The header line gives the implementation names.
                    if value[0] == "file":
                        headers = value[1:]
                        accu = [0] * len(headers)
                        total_char_count = [0] * len(headers)
                        continue

                    gt_ref = value[0].split('.')
                    gt_chars = None  # counted once per row, on first use
                    for i in range(1, min(len(value), len(headers) + 1)):
                        try:
                            errors = int(value[i])
                        except ValueError:
                            # No usable result for this implementation.
                            continue

                        if gt_chars is None:
                            gt_chars = _ocr_count_chars(
                                conf.gt_ocr_dir + os.sep + text_size + os.sep
                                + gt_ref[0] + gt_infix + gt_ref[1] + ".txt")

                        accu[i - 1] += errors
                        total_char_count[i - 1] += gt_chars

            # For each implementation
            with open(summary_base + ".csv", 'w') as of:
                for name, errors, chars in zip(headers, accu, total_char_count):
                    log.debug(name + ": " + str(float(errors)) + "* 100 / " + str(float(chars)))
                    if chars:
                        rate = float(errors * 100) / float(chars)
                    else:
                        log.warning("No OCR result available for " + name + " (" + text_size
                                    + ", " + quality + "): error rate is undefined.")
                        rate = float('nan')
                    of.write(name + ": " + str(rate) + '\n')

            # HTML output, generated once the summary file is closed so
            # that its content is complete on disk.
            # NOTE(review): summary lines use ": " as separator while
            # generate_html_output splits on ';' -- confirm the intended
            # table layout.
            generate_html_output(conf,
                                 summary_base + ".csv",
                                 summary_base + "_summary.html",
                                 ["Method", "OCR error"])


#------------------------------------
def generate_full_ocr_result_html_output(conf, html):
    """Write the global OCR error table as an HTML file.

    The table has one row per implementation of
    ``conf.use_implementations`` and one column per (quality, text size)
    pair, in the order of ``conf.use_quality`` then
    ``conf.use_text_size``.  Values are read from the
    ``ocr_evaluation_<size>_<quality>.csv`` files of ``conf.result_dir``,
    whose lines have the form ``<name>: <value>``.

    Args:
        conf: configuration object.
        html: path of the HTML file to write.
    """
    text_sizes = list(conf.use_text_size)
    qualities = list(conf.use_quality)
    implementations = list(conf.use_implementations)
    ntextsize = len(text_sizes)
    ncols = len(qualities) * ntextsize

    # Short column labels; an unknown text size is shown with its name.
    size_labels = {"small": "S", "medium": "M", "large": "L"}

    parts = ["<table border=\"1\">"]

    # First header row.
    parts.append("<tr>")
    parts.append("<td align=\"center\"><strong>Method</strong></td>")
    parts.append("<td colspan=\"" + str(ncols) + "\" align=\"center\"><strong>OCR error (%)</strong></td>")
    parts.append("</tr>")

    # Second header row: one group per quality.
    parts.append("<tr>")
    parts.append("<td align=\"right\" style=\"font-size:8pt\"><strong>Set &rarr;</strong></td>")
    for quality in qualities:
        parts.append("<td colspan=\"" + str(ntextsize) + "\"><strong>" + quality + " documents</strong></td>")
    parts.append("</tr>")

    # Third header row: text sizes, in the same order as the data
    # columns filled below.
    parts.append("<tr>")
    parts.append("<td align=\"right\" style=\"font-size:8pt\"><strong>Subset &rarr;</strong></td>")
    for quality in qualities:
        for text_size in text_sizes:
            parts.append("<td><strong>" + size_labels.get(text_size, text_size) + "</strong></td>")
    parts.append("</tr>")

    # results[row] = [name, value, value, ...]; None marks a missing value.
    results = [[impl] + [None] * ncols for impl in implementations]

    for quality_id, quality in enumerate(qualities):
        for text_size_id, text_size in enumerate(text_sizes):
            column = quality_id * ntextsize + text_size_id + 1
            path = conf.result_dir + os.sep + "ocr_evaluation_" + text_size + "_" + quality + ".csv"
            with open(path) as f:
                for line_id, line in enumerate(f):
                    value = line.split(': ')
                    # Skip malformed lines and lines beyond the known
                    # implementations.
                    if len(value) < 2 or line_id >= len(results):
                        continue
                    results[line_id][0] = value[0]
                    results[line_id][column] = float(value[1])

    for row in results:
        parts.append("<tr>")
        parts.append("<td><strong>" + row[0] + "</strong></td>")
        for score in row[1:]:
            cell = "n/a" if score is None else str(round(score, 2))
            parts.append("<td>" + cell + "</td>")
        parts.append("</tr>")

    parts.append("</table>")

    # Write data
    with open(html, 'w') as f:
        f.write("".join(parts))



#------------------------------------
def parse_options(conf):
    """Parse the command line and store the options as attributes of ``conf``.

    ``conf`` is used as the argparse namespace; its ``bench_version``
    attribute is shown by ``--version``.
    """
    parser = argparse.ArgumentParser(
        description='Benchmark tool for LRDE Document Binarization Dataset.',
        epilog='Copyright (C) 2013 EPITA Research and Development Laboratory (LRDE) http://olena.lrde.epita.fr | Contact: olena@lrde.epita.fr',
        add_help=True)

    # Boolean switches: type of evaluation, then data generation.
    switches = (
        ('--disable-bin-eval',
         'Disable the pixel-based evaluation which compares raw binarization outputs with a groundtruth.'),
        ('--disable-ocr-eval',
         'Disable the OCR-based evaluation which compares how OCR perform on binarization outputs.'),
        ('--force-regen-output',
         'Always compute binarization even if a result already exists.'),
    )
    for flag, description in switches:
        parser.add_argument(flag, action='store_true', help=description)

    # Modifiers: space-separated lists given as a single string.
    modifiers = (
        ('--use-implementations', '"impl1 impl2 impl3"',
         'Restrict the benchmark to specific implementations.',
         'sauvola sauvola_msk sauvola_mskx wolf otsu niblack kim'),
        ('--use-text-size', '"text_size1 text_size2"',
         'Restrict the benchmark to specific text size.',
         'small medium large'),
        ('--use-quality', '"quality1 quality2"',
         'Restrict the benchmark to specific quality.',
         'clean scanned orig'),
    )
    for flag, placeholder, description, default_value in modifiers:
        parser.add_argument(flag, metavar=placeholder, help=description,
                            default=default_value)

    # Informational options.
    parser.add_argument('--list-implementations', action='store_true',
                        help='List default available implementations.')
    parser.add_argument('--version', action='version',
                        version='%(prog)s ' + conf.bench_version + " - Copyright LRDE 2013")

    parser.parse_args(namespace=conf)


#------------------------------------
def handle_options(conf):
    """Act on the parsed options and normalize the list-valued ones.

    With ``--list-implementations`` the supported implementations are
    printed and the process exits with status 0.  Otherwise the
    space-separated strings ``conf.use_implementations``,
    ``conf.use_text_size`` and ``conf.use_quality`` are replaced by lists
    of names.
    """
    if conf.list_implementations:
        print("Supported implementations are: " + " ".join(conf.supported_impl))
        sys.exit(0)

    # Nothing special, about to start.
    log.debug("Starting benchmark with the following implementation: " + conf.use_implementations)

    # split() without argument ignores repeated or surrounding blanks, so
    # no empty name ends up in the lists.
    conf.use_implementations = conf.use_implementations.split()
    conf.use_text_size = conf.use_text_size.split()
    conf.use_quality = conf.use_quality.split()



# FIXME:
# - HTML output
# - do not generate lines every time
# - OCR evaluation
# - README: how to use, where to find the results, how to test our own impl, license
#------------------------------------
def main():
    """Run the benchmark according to the command-line options."""
    # Parse options
    conf = Config()
    parse_options(conf)
    handle_options(conf)

    # Prepare common output data: clean documents are always binarized.
    compute_binarization(conf, "clean")

    if not conf.disable_bin_eval:
        bin_evaluation(conf)

    if conf.disable_ocr_eval:
        return

    # The OCR evaluation also needs the other document qualities and the
    # extracted text lines.
    for quality in ("orig", "scanned"):
        compute_binarization(conf, quality)
    compute_lines(conf)
    ocr_evaluation(conf)


# Script entry point.
# NOTE(review): the call is not protected by an
# `if __name__ == "__main__":` guard, so loading this file as a module
# also runs the whole benchmark.
main()
