#!/usr/bin/python2
# -*- coding: utf-8; mode: Python; indent-tabs-mode: t; tab-width: 4; python-indent: 4 -*-

# Copyright (C) 2012, 2015  Olga Yakovleva <yakovleva.o.v@gmail.com>

# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.

# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.

# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <https://www.gnu.org/licenses/>.

import argparse
import codecs
import collections
import subprocess
import json
import bisect

class unit_indexer(object):
	"""Hands out consecutive integer ids (0, 1, 2, ...) to hashable units.

	The same unit always maps to the same id; ids are allocated in the
	order in which previously unseen units are added.
	"""

	def __init__(self):
		self.cache={}
		self.next_index=0

	def add(self,unit):
		"""Return the id of unit, allocating the next free id on first sight."""
		try:
			return self.cache[unit]
		except KeyError:
			pass
		assigned=self.next_index
		self.cache[unit]=assigned
		self.next_index=assigned+1
		return assigned

	def count(self):
		"""Return how many distinct units have been given an id so far."""
		return self.next_index

class sentence(object):
	"""One corpus sentence, usable as a node of a sentence_list.

	text  -- the sentence text
	order -- the position key given by the creator
	units -- list of units found in the sentence, initially empty
	prev, next -- links to the neighbouring nodes, initially None
	"""

	def __init__(self,text,order):
		# Link fields start detached; a sentence_list fills them in
		self.prev=self.next=None
		self.units=list()
		self.order=order
		self.text=text

class sentence_list(object):
	"""Doubly linked list threaded through the prev/next fields of its items."""

	def __init__(self):
		self.head=self.tail=None

	def append(self,sent):
		"""Attach sent after the current tail (or make it the only item)."""
		old_tail=self.tail
		if self.head is None:
			self.head=sent
		else:
			sent.prev=old_tail
			old_tail.next=sent
		self.tail=sent

	def remove(self,sent):
		"""Unlink sent from the list, clear its links and return it."""
		before=sent.prev
		after=sent.next
		# Bridge the neighbours over the removed item
		if before is not None:
			before.next=after
		if after is not None:
			after.prev=before
		# Move the list ends if the removed item was one of them
		if sent is self.head:
			self.head=after
		if sent is self.tail:
			self.tail=before
		sent.prev=sent.next=None
		return sent

	def empty(self):
		"""Return True when the list holds no items."""
		return self.head is None

	def first(self):
		"""Return the first item, or None for an empty list."""
		return self.head

	def last(self):
		"""Return the last item, or None for an empty list."""
		return self.tail

	def __iter__(self):
		return sentence_list_iterator(self)

class sentence_list_iterator(object):
	"""Forward iterator over the items of a sentence_list.

	The successor is read before an item is handed out, so the caller may
	unlink the item it has just received without disturbing the iteration.
	Works with both the Python 2 (next) and Python 3 (__next__) iterator
	protocols.
	"""

	def __init__(self,sentences):
		# Start at the head of the list; None means the list is empty
		self.item=sentences.head

	def __iter__(self):
		return self

	def next(self):
		"""Return the current item and advance; raise StopIteration at the end."""
		if self.item is None:
			raise StopIteration
		result=self.item
		self.item=result.next
		return result

	# Python 3 looks for __next__ instead of next
	__next__=next

class prompt_selector(object):
	"""Greedy selection of a compact set of prompts covering the corpus units.

	The corpus is read from the file "sentences" (one sentence per line,
	UTF-8) and split into statements and questions. For each group the
	selector first picks sentences until every unit occurs at least the
	requested number of times (or as often as the corpus allows), then
	drops selected sentences that turned out to be redundant.

	Subclasses must implement determine_units(sentences), which fills the
	units list of every sentence with the raw units found in it.
	"""

	def __init__(self,conf):
		"""Load the sentences and index their units.

		conf is a mapping providing "coverage_in_statements",
		"coverage_in_questions" and "alphabet".
		"""
		self.unit_idx=unit_indexer()
		self.coverage_in_statements=conf["coverage_in_statements"]
		self.coverage_in_questions=conf["coverage_in_questions"]
		self.alphabet=set(conf["alphabet"])
		self.statements=sentence_list()
		self.questions=sentence_list()
		self.prompts=[]
		order=0
		print("Loading sentences")
		with codecs.open("sentences","r","utf-8") as f:
			for line in f:
				text=line.strip()
				if not text:
					continue
				sent=sentence(text,order)
				order+=1
				# Questions form their own group only if coverage is requested for them
				if self.is_question(sent.text) and self.coverage_in_questions>0:
					self.questions.append(sent)
				else:
					self.statements.append(sent)
		print("Determining units")
		self.determine_units(self.statements)
		self.determine_units(self.questions)
		print("Normalizing units")
		self.normalize_all_sentence_units()

	def is_question(self,sent):
		"""Return True if a "?" occurs after the last whitespace character of the string sent."""
		for c in reversed(sent):
			if c=="?":
				return True
			if c.isspace():
				return False
		return False

	def normalize_all_sentence_units(self):
		"""Normalize the units of every statement and every question."""
		for sent in self.statements:
			self.normalize_sentence_units(sent)
		for sent in self.questions:
			self.normalize_sentence_units(sent)

	def normalize_sentence_units(self,sent):
		"""Replace the raw units of sent by their integer indexes.

		Sets sent.cost to the number of raw units, sent.instances to a list
		of (unit index, number of occurrences) pairs and sent.units to the
		sorted list of distinct unit indexes.
		"""
		sent.cost=len(sent.units)
		counter=collections.Counter(self.unit_idx.add(u) for u in sent.units)
		sent.instances=list(counter.items())
		# Sorted so that sentence_contains_unit can use binary search
		sent.units=sorted(i for i,n in sent.instances)

	def order_prompts(self):
		"""Collect the selected statements and questions into self.prompts in corpus order."""
		self.prompts.extend(self.statements)
		self.prompts.extend(self.questions)
		self.prompts.sort(key=lambda s: s.order)

	def find_rarest_unit(self,total_state,state,target_state):
		"""Return the index of the under-covered unit with the smallest total count.

		A unit is under-covered while state[i]<target_state[i]. The first
		such unit wins a tie. Returns None when every unit has reached its
		target.
		"""
		unit=None
		min_count=0
		for i,covered in enumerate(state):
			if covered>=target_state[i]:
				continue
			if unit is None or total_state[i]<min_count:
				min_count=total_state[i]
				unit=i
		return unit

	def sentence_contains_unit(self,sent,u):
		"""Return True if unit index u is in the sorted list sent.units."""
		i=bisect.bisect_left(sent.units,u)
		return ((i<len(sent.units)) and (sent.units[i]==u))

	def count_useful_instances(self,sent,state,target_state):
		"""Return how many unit instances of sent still count towards the targets.

		For each unit only as many instances are counted as are missing
		from state to reach target_state.
		"""
		count=0
		for u,n in sent.instances:
			missing=target_state[u]-state[u]
			if missing>0:
				count+=min(n,missing)
		return count

	def find_most_useful_sentence(self,sentences,total_state,state,target_state):
		"""Return the best sentence containing the rarest under-covered unit.

		The best sentence has the most useful instances; among equals the
		one with the lower cost is preferred. Sentences that contain the
		unit but contribute nothing are removed from sentences on the way.
		Returns None when no unit is under-covered or no sentence qualifies.
		"""
		unit=self.find_rarest_unit(total_state,state,target_state)
		if unit is None:
			return None
		max_count=0
		min_cost=0
		res=None
		sent=sentences.first()
		while sent is not None:
			# Remember the successor first: removing sent clears its links
			next_sent=sent.next
			if self.sentence_contains_unit(sent,unit):
				count=self.count_useful_instances(sent,state,target_state)
				if count==0:
					sentences.remove(sent)
				elif (count>max_count) or ((count==max_count) and (sent.cost<min_cost)):
					res=sent
					max_count=count
					min_cost=sent.cost
			sent=next_sent
		return res

	def select(self,sentences,total_state,state,target_state):
		"""Move sentences into a new list until the targets are met.

		state is updated in place with the units of each selected sentence.
		Returns the sentence_list of selected sentences.
		"""
		selected=sentence_list()
		it=0
		while True:
			it+=1
			print("Iteration {}".format(it))
			sent=self.find_most_useful_sentence(sentences,total_state,state,target_state)
			if sent is None:
				break
			sentences.remove(sent)
			selected.append(sent)
			self.update_state(state,sent,True)
		return selected

	def get_total_state(self,sentences):
		"""Return a list with the total number of instances of every unit in sentences."""
		state=[0]*self.unit_idx.count()
		for sent in sentences:
			self.update_state(state,sent,True)
		return state

	def get_target_state(self,total_state,target_coverage):
		"""Return the per-unit targets: the total count capped at target_coverage."""
		return [min(total,target_coverage) for total in total_state]

	def is_useless_sentence(self,sent,state,target_state):
		"""Return True if every target would still be met without sent."""
		for i,n in sent.instances:
			if state[i]-n<target_state[i]:
				return False
		return True

	def find_least_useful_sentence(self,sentences,state,target_state):
		"""Return the costliest sentence that can be dropped, or None if there is none."""
		max_cost=0
		res=None
		for sent in sentences:
			if sent.cost>max_cost and self.is_useless_sentence(sent,state,target_state):
				res=sent
				max_cost=sent.cost
		return res

	def update_state(self,state,sent,add):
		"""Add the unit counts of sent to state in place, or subtract them if add is false."""
		sign=1 if add else -1
		for i,n in sent.instances:
			state[i]+=sign*n

	def reduce(self,sentences,state,target_state):
		"""Remove redundant sentences from sentences one by one, updating state in place."""
		it=0
		while True:
			it+=1
			print("Iteration {}".format(it))
			sent=self.find_least_useful_sentence(sentences,state,target_state)
			if sent is None:
				return
			sentences.remove(sent)
			self.update_state(state,sent,False)

	def process(self,sentences,target_coverage):
		"""Select and reduce the prompts from one group of sentences.

		Returns the list of selected sentences. An empty input list is
		returned unchanged, so the result can always be iterated over.
		"""
		if sentences.empty():
			# Nothing to choose from; the caller still needs an iterable result
			return sentences
		total_state=self.get_total_state(sentences)
		state=[0]*self.unit_idx.count()
		target_state=self.get_target_state(total_state,target_coverage)
		print("Selecting")
		selected=self.select(sentences,total_state,state,target_state)
		print("Reducing")
		self.reduce(selected,state,target_state)
		return selected

	def __call__(self):
		"""Run the selection for statements and questions and return the prompts in corpus order."""
		print("Processing statements")
		self.statements=self.process(self.statements,self.coverage_in_statements)
		print("Processing questions")
		self.questions=self.process(self.questions,self.coverage_in_questions)
		self.order_prompts()
		return self.prompts

	def get_word(self,token):
		"""Return the lowercased token trimmed to its first and last alphabet letters.

		Characters between those two letters are kept as they are. Returns
		an empty string if the token has no alphabet letter.
		"""
		ltoken=token.lower()
		letter_positions=[i for i,c in enumerate(ltoken) if c in self.alphabet]
		if not letter_positions:
			return u""
		return ltoken[letter_positions[0]:letter_positions[-1]+1]

class bigram_prompt_selector(prompt_selector):
	"""Prompt selector whose units are bigrams of written symbols.

	A symbol is a letter of the alphabet, one of the configured digraphs,
	or the boundary mark "#".
	"""

	def __init__(self,conf):
		"""conf must provide "digraphs" in addition to the keys used by prompt_selector."""
		# The base constructor already calls determine_units, which needs the digraphs
		self.digraphs=set(conf["digraphs"])
		prompt_selector.__init__(self,conf)

	def get_words(self,token):
		"""Return the maximal runs of alphabet letters in the lowercased token."""
		ltoken=token.lower()
		words=[]
		start=None
		for i,c in enumerate(ltoken):
			if c in self.alphabet:
				if start is None:
					start=i
			elif start is not None:
				words.append(ltoken[start:i])
				start=None
		if start is not None:
			# The token ends with a letter run
			words.append(ltoken[start:])
		return words

	def get_word_symbols(self,word):
		"""Split word into symbols, preferring a digraph over a single letter at each position."""
		symbols=[]
		n=len(word)
		i=0
		while i<n:
			pair=word[i:i+2]
			if len(pair)==2 and pair in self.digraphs:
				symbols.append(pair)
				i+=2
			else:
				symbols.append(word[i])
				i+=1
		return symbols

	def get_sentence_symbols(self,sentence):
		"""Return the symbol sequence of a sentence, delimited by "#" marks.

		The sequence starts and ends with "#". A "#" is also inserted where
		a whitespace-separated token begins or ends with a character that
		is not in the alphabet; consecutive marks are never produced.
		"""
		symbols=["#"]
		for token in sentence.text.split():
			if symbols[-1]!="#" and token[0].lower() not in self.alphabet:
				symbols.append("#")
			for word in self.get_words(token):
				symbols.extend(self.get_word_symbols(word))
			if symbols[-1]!="#" and token[-1].lower() not in self.alphabet:
				symbols.append("#")
		if symbols[-1]!="#":
			symbols.append("#")
		return symbols

	def get_n_grams(self,symbols,n):
		"""Return all runs of n consecutive symbols, each joined into one string.

		The result is empty when there are fewer than n symbols.
		"""
		return [u"".join(symbols[i:i+n]) for i in range(len(symbols)-n+1)]

	def get_sentence_units(self,sentence):
		"""Return the symbol bigrams of a sentence."""
		return self.get_n_grams(self.get_sentence_symbols(sentence),2)

	def determine_units(self,sentences):
		"""Append the symbol bigrams of each sentence to its units list."""
		for sent in sentences:
			sent.units.extend(self.get_sentence_units(sent))

class diphone_prompt_selector(prompt_selector):
	"""Prompt selector whose units are diphones (pairs of adjacent phones).

	The phones are obtained by running the external program
	RHVoice-transcribe-sentences on an SSML document built from the
	sentences.
	"""

	def __init__(self,conf):
		"""conf must provide "language" in addition to the keys used by prompt_selector."""
		# The base constructor already calls determine_units, which needs the language
		self.language=conf["language"]
		self.infile="sentences"
		prompt_selector.__init__(self,conf)

	def determine_units(self,sentences):
		"""Transcribe the sentences and append the diphones of each one to its units list.

		Writes the file "ssml", runs RHVoice-transcribe-sentences to produce
		the file "transcription" and reads it back, pairing its lines with
		the sentences in order. Raises subprocess.CalledProcessError if the
		program exits with a non-zero status.
		"""
		from xml.sax.saxutils import escape
		with codecs.open("ssml","w","utf-8") as f_out:
			f_out.write('<speak xml:lang="{}">\n'.format(self.language))
			for sent in sentences:
				# Escape the text so that characters like & and < cannot break the markup
				f_out.write(u"<s>{}</s>\n".format(escape(sent.text)))
			f_out.write("</speak>\n")
		subprocess.check_call(["RHVoice-transcribe-sentences","ssml","transcription"])
		with codecs.open("transcription","r","utf-8") as f_in:
			for sent,line in zip(sentences,f_in):
				phones=line.split()
				# A diphone is written as the two adjacent phones joined by "+"
				sent.units.extend(left+"+"+right for left,right in zip(phones,phones[1:]))

def select_prompts(conf):
	"""Select the prompts described by conf and write the result files.

	A diphone selector is used when conf has a non-empty "language",
	otherwise a bigram selector. The prompts are written to "script.txt"
	(UTF-8 with a byte order mark, each prompt followed by an empty line,
	CRLF line ends) and the sorted set of their words to "vocab" (UTF-8,
	one word per line).
	"""
	# A missing "language" key means the same as an empty one
	language=conf.get("language")
	sel=diphone_prompt_selector(conf) if language else bigram_prompt_selector(conf)
	prompts=sel()
	with codecs.open("script.txt","w","utf-8") as f:
		f.write(u"\ufeff")
		for prompt in prompts:
			f.write(prompt.text)
			f.write("\r\n\r\n")
	words=set()
	for prompt in prompts:
		words.update(sel.get_word(token) for token in prompt.text.split())
	# Tokens without any letter yield an empty word, which is not a vocabulary entry
	words.discard(u"")
	with codecs.open("vocab","w","utf-8") as f:
		for word in sorted(words):
			f.write(word)
			f.write("\n")

if __name__=="__main__":
	# Script entry point: read the JSON configuration named by --config
	# and run the prompt selection with it
	arg_parser=argparse.ArgumentParser(description="Select prompts for recording")
	arg_parser.add_argument(
		"--config",
		required=True,
		help="the path to the configuration file")
	options=arg_parser.parse_args()
	with open(options.config,"r") as config_file:
		settings=json.load(config_file)
	select_prompts(settings)
