[irstlm] 03/126:
Giulio Paci
giuliopaci-guest at moszumanska.debian.org
Tue May 17 07:46:38 UTC 2016
This is an automated email from the git hooks/post-receive script.
giuliopaci-guest pushed a commit to annotated tag adaptiveLM.v0.1
in repository irstlm.
commit 55cf030f4b0de41049de8c996f3168080bf1eaa8
Author: Marcello Federico <mrcfdr at gmail.com>
Date: Mon Jul 20 09:18:13 2015 +0200
---
src/CMakeLists.txt | 68 ++
src/cmd.c | 1295 +++++++++++++++++++++
src/cmd.h | 119 ++
src/compile-lm.cpp | 526 +++++++++
src/cplsa.cpp | 641 +++++++++++
src/cplsa.h | 85 ++
src/crc.cpp | 75 ++
src/crc.h | 33 +
src/cswa.cpp | 214 ++++
src/cswam.cpp | 1484 ++++++++++++++++++++++++
src/cswam.h | 219 ++++
src/dict.cpp | 164 +++
src/dictionary.cpp | 577 ++++++++++
src/dictionary.h | 249 ++++
src/doc.cpp | 111 ++
src/doc.h | 43 +
src/dtsel.cpp | 397 +++++++
src/gzfilebuf.h | 90 ++
src/htable.cpp | 107 ++
src/htable.h | 270 +++++
src/index.h | 18 +
src/interplm.cpp | 536 +++++++++
src/interplm.h | 158 +++
src/interpolate-lm.cpp | 564 +++++++++
src/linearlm.cpp | 233 ++++
src/linearlm.h | 55 +
src/lmContainer.cpp | 167 +++
src/lmContainer.h | 198 ++++
src/lmInterpolation.cpp | 243 ++++
src/lmInterpolation.h | 131 +++
src/lmclass.cpp | 236 ++++
src/lmclass.h | 104 ++
src/lmmacro.cpp | 903 +++++++++++++++
src/lmmacro.h | 133 +++
src/lmtable.cpp | 2948 +++++++++++++++++++++++++++++++++++++++++++++++
src/lmtable.h | 660 +++++++++++
src/mdiadapt.cpp | 2175 ++++++++++++++++++++++++++++++++++
src/mdiadapt.h | 158 +++
src/mempool.cpp | 505 ++++++++
src/mempool.h | 194 ++++
src/mfstream.cpp | 219 ++++
src/mfstream.h | 218 ++++
src/mixture.cpp | 581 ++++++++++
src/mixture.h | 95 ++
src/n_gram.cpp | 299 +++++
src/n_gram.h | 129 +++
src/ngramcache.cpp | 159 +++
src/ngramcache.h | 93 ++
src/ngramtable.cpp | 1870 ++++++++++++++++++++++++++++++
src/ngramtable.h | 379 ++++++
src/ngt.cpp | 506 ++++++++
src/normcache.cpp | 123 ++
src/normcache.h | 53 +
src/plsa.cpp | 250 ++++
src/prune-lm.cpp | 175 +++
src/quantize-lm.cpp | 512 ++++++++
src/score-lm.cpp | 122 ++
src/shiftlm.cpp | 830 +++++++++++++
src/shiftlm.h | 108 ++
src/stream-tlm.cpp | 575 +++++++++
src/thpool.c | 551 +++++++++
src/thpool.h | 166 +++
src/timer.cpp | 109 ++
src/timer.h | 35 +
src/tlm.cpp | 586 ++++++++++
src/util.cpp | 369 ++++++
src/util.h | 97 ++
src/verify-caching.cpp | 91 ++
68 files changed, 26386 insertions(+)
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
new file mode 100644
index 0000000..a23784e
--- /dev/null
+++ b/src/CMakeLists.txt
@@ -0,0 +1,68 @@
+# Set output directory
+SET(EXECUTABLE_OUTPUT_PATH ${CMAKE_INSTALL_PREFIX}/bin)
+SET(LIBRARY_OUTPUT_PATH ${CMAKE_INSTALL_PREFIX}/lib)
+
+ADD_DEFINITIONS("-D_LARGE_FILES")
+ADD_DEFINITIONS("-D_FILE_OFFSET_BITS=64")
+ADD_DEFINITIONS("-DMYCODESIZE=3")
+ADD_DEFINITIONS("-DDEBUG")
+ADD_DEFINITIONS("-DTRACE_LEVEL=1")
+
+if (CXX0)
+ MESSAGE( STATUS "HAVE_CXX0=true; hence, variable HAVE_CXX0 is set" )
+ SET(STD_FLAG "-std=c++0x")
+ ADD_DEFINITIONS("-DHAVE_CXX0")
+else()
+ MESSAGE( STATUS "HAVE_CXX0=false; hence, variable HAVE_CXX0 is unset" )
+ SET(STD_FLAG "")
+ ADD_DEFINITIONS("-UHAVE_CXX0")
+endif()
+
+SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -g ${STD_FLAG} -isystem/usr/include -W -Wall -ffor-scope")
+SET(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS}")
+
+INCLUDE_DIRECTORIES("${PROJECT_SOURCE_DIR}/src")
+
+SET( LIB_IRSTLM_SRC
+ cmd.h cmd.c
+ thpool.h thpool.c
+ gzfilebuf.h index.h
+ dictionary.h dictionary.cpp
+ htable.h htable.cpp
+ lmContainer.h lmContainer.cpp
+ lmclass.h lmclass.cpp
+ lmmacro.h lmmacro.cpp
+ lmtable.h lmtable.cpp
+ lmInterpolation.h lmInterpolation.cpp
+ mempool.h mempool.cpp
+ mfstream.h mfstream.cpp
+ n_gram.h n_gram.cpp
+ ngramcache.h ngramcache.cpp
+ ngramtable.h ngramtable.cpp
+ timer.h timer.cpp
+ util.h util.cpp
+ crc.h crc.cpp
+ interplm.h interplm.cpp
+ linearlm.h linearlm.cpp
+ mdiadapt.h mdiadapt.cpp
+ mixture.h mixture.cpp
+ normcache.h normcache.cpp
+ shiftlm.h shiftlm.cpp
+ cplsa.h cplsa.cpp
+ cswam.h cswam.cpp
+ doc.h doc.cpp
+)
+
+ADD_LIBRARY(irstlm STATIC ${LIB_IRSTLM_SRC})
+LINK_DIRECTORIES (${LIBRARY_OUTPUT_PATH})
+
+FOREACH(CMD dict ngt tlm dtsel plsa cswa compile-lm interpolate-lm prune-lm quantize-lm score-lm)
+
+ADD_EXECUTABLE(${CMD} ${CMD}.cpp)
+TARGET_LINK_LIBRARIES (${CMD} irstlm -lm -lz -lpthread)
+
+ENDFOREACH()
+
+#INSTALL INCLUDE FILES
+FILE(GLOB includes src *.h)
+INSTALL(FILES ${includes} DESTINATION include)
diff --git a/src/cmd.c b/src/cmd.c
new file mode 100644
index 0000000..e95d964
--- /dev/null
+++ b/src/cmd.c
@@ -0,0 +1,1295 @@
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+
+#ifndef _WIN32_WCE
+#include <stdio.h>
+#endif
+
+#include <stdlib.h>
+#include <ctype.h>
+#include <string.h>
+#include <stdarg.h>
+#if defined(_WIN32)
+#include <windows.h>
+#else
+#include <unistd.h>
+#endif
+
+#ifdef USE_UPIO
+#include "missing.h"
+#include "updef.h"
+#endif
+
+#include "cmd.h"
+
+#ifdef NEEDSTRDUP
+char *strdup(const char *s);
+#endif
+
+#define LINSIZ 10240
+
+
+static Bool_T BoolEnum[] = {
+ { (char*)"FALSE", FALSE},
+ { (char*)"TRUE", TRUE},
+ { (char*)"false", FALSE},
+ { (char*)"true", TRUE},
+ { (char*)"0", FALSE},
+ { (char*)"1", TRUE},
+ { (char*)"NO", FALSE},
+ { (char*)"YES", TRUE},
+ { (char*)"No", FALSE},
+ { (char*)"Yes", TRUE},
+ { (char*)"no", FALSE},
+ { (char*)"yes", TRUE},
+ { (char*)"N", FALSE},
+ { (char*)"Y", TRUE},
+ { (char*)"n", FALSE},
+ { (char*)"y", TRUE},
+ END_ENUM
+};
+
+
+static char
+*GetLine(FILE *fp,
+ int n,
+ char *Line),
+**str2array(char *s,
+ char *sep);
+static int
+str2narray(int type,
+ char *s,
+ char *sep,
+ void **a);
+
+static int
+Scan(char *ProgName,
+ Cmd_T *cmds,
+ char *Line),
+SetParam(Cmd_T *cmd,
+ char *s),
+SetEnum(Cmd_T *cmd,
+ char *s),
+SetBool(Cmd_T *cmd,
+ char *s),
+SetFlag(Cmd_T *cmd,
+ char *s),
+SetSubrange(Cmd_T *cmd,
+ char *s),
+SetStrArray(Cmd_T *cmd,
+ char *s),
+SetNumArray(Cmd_T *cmd,
+ char *s),
+SetGte(Cmd_T *cmd,
+ char *s),
+SetLte(Cmd_T *cmd,
+ char *s),
+CmdError(char *opt),
+EnumError(Cmd_T *cmd,
+ char *s),
+BoolError(Cmd_T *cmd,
+ char *s),
+SubrangeError(Cmd_T *cmd,
+ int n),
+GteError(Cmd_T *cmd,
+ int n),
+LteError(Cmd_T *cmd,
+ int n),
+PrintParam(Cmd_T *cmd,
+ int TypeFlag,
+ int ValFlag,
+ FILE *fp),
+PrintParams4(int TypeFlag,
+ int ValFlag,
+ int MsgFlag,
+ FILE *fp),
+FreeParam(Cmd_T *cmd),
+PrintEnum(Cmd_T *cmd,
+ int TypeFlag,
+ int ValFlag,
+ FILE *fp),
+PrintBool(Cmd_T *cmd,
+ int TypeFlag,
+ int ValFlag,
+ FILE *fp),
+PrintFlag(Cmd_T *cmd,
+ int TypeFlag,
+ int ValFlag,
+ FILE *fp),
+PrintStrArray(Cmd_T *cmd,
+ int TypeFlag,
+ int ValFlag,
+ FILE *fp),
+PrintIntArray(Cmd_T *cmd,
+ int TypeFlag,
+ int ValFlag,
+ FILE *fp),
+PrintDblArray(Cmd_T *cmd,
+ int TypeFlag,
+ int ValFlag,
+ FILE *fp),
+BuildCmdList(Cmd_T **cmdp,
+ int *cmdSz,
+ char *ParName,
+ va_list args),
+StoreCmdLine(char *s);
+
+static Cmd_T *pgcmds = 0;
+static int pgcmdN = 0;
+static int pgcmdSz = 0;
+static char *SepString = " \t\r\n";
+static char *ProgName = 0;
+static char **CmdLines = 0;
+static int CmdLinesSz = 0,
+CmdLinesL = 0;
+
+int
+DeclareParams(char *ParName,
+ ...)
+{
+ va_list args;
+
+ va_start(args, ParName);
+ pgcmdN = BuildCmdList(&pgcmds, &pgcmdSz, ParName, args);
+ va_end(args);
+ return 0;
+}
+
+int
+GetParams(int *n,
+ char ***a,
+ char *DefCmd)
+{
+ char *Line;
+ int i,
+ argc = *n;
+ char **argv = *a,
+ *s,
+ *p,
+ *defCmd;
+
+#if defined(MSDOS)||defined(_WIN32)
+ char *dot = 0;
+#endif
+ extern char **environ;
+
+ if(!(Line=malloc(LINSIZ))) {
+ fprintf(stderr, "GetParams(): Unable to alloc %d bytes\n",
+ LINSIZ);
+ exit(IRSTLM_CMD_ERROR_MEMORY);
+ }
+ for(ProgName=*argv+strlen(*argv);
+ ProgName-->*argv && *ProgName!='/' && *ProgName!='\\';);
+ ++ProgName;
+#if defined(MSDOS)||defined(_WIN32)
+ if((dot=strchr(ProgName, '.'))) *dot=0;
+#endif
+ --argc;
+ ++argv;
+ for(i=0; environ[i]; i++) {
+ if(strncmp(environ[i], "cmd_", 4)) continue;
+ strcpy(Line, environ[i]+4);
+ if(!(p=strchr(Line, '='))) continue;
+ *p=' ';
+ StoreCmdLine(Line);
+ if(Scan(ProgName, pgcmds, Line)) CmdError(environ[i]);
+ }
+ if((defCmd=DefCmd?(DefCmd=strdup(DefCmd)):0)) {
+ defCmd += strspn(defCmd, "\n\r");
+ }
+ for(;;) {
+ char *CmdFile = NULL;
+ if(argc && argv[0][0]=='-' && argv[0][1]=='=') {
+ CmdFile = argv[0]+2;
+ ++argv;
+ --argc;
+ defCmd = 0;
+ }
+ if(!CmdFile) {
+ int i;
+ char ch;
+
+ if(!defCmd||!(i=strcspn(defCmd, "\n\r"))) break;
+ ch = defCmd[i];
+ defCmd[i] = 0;
+ CmdFile = defCmd;
+ defCmd += i+!!ch;
+ defCmd += strspn(defCmd, "\n\r");
+ }
+
+ int IsPipe = !strncmp(CmdFile, "@@", 2);
+
+ FILE *fp = IsPipe
+ ? popen(CmdFile+2, "r")
+ : strcmp(CmdFile, "-")
+ ? fopen(CmdFile, "r")
+ : stdin;
+
+
+ if(!fp) {
+ if(defCmd) continue;
+ fprintf(stderr, "Unable to open command file %s\n", CmdFile);
+ exit(IRSTLM_CMD_ERROR_IO);
+ }
+ while(GetLine(fp, LINSIZ, Line) && strcmp(Line, "\\End")) {
+ StoreCmdLine(Line);
+ if(Scan(ProgName, pgcmds, Line)) CmdError(Line);
+ }
+
+ if(fp!=stdin) {
+ if(IsPipe)
+ pclose(fp);
+ else
+ fclose(fp);
+ }
+ }
+ if(DefCmd) free(DefCmd);
+
+ // while(argc && **argv=='-'){
+ while(argc){
+ if (**argv=='-'){
+ s=strchr(*argv, '=');
+
+ //allows double dash for parameters
+ int dash_number=1;
+ if (*(*argv+1) == '-') dash_number++;
+ if (s){
+ *s = ' ';
+ if((p=strchr(*argv+dash_number, '.'))&&p<s) {
+ strcpy(Line, *argv+dash_number);
+ } else {
+ sprintf(Line, "%s/%s", ProgName, *argv+dash_number);
+ }
+ *s = '=';
+ }else{ //force the true value for the parameters without a value
+ sprintf(Line, "%s/%s", ProgName, *argv+dash_number);
+ }
+
+ StoreCmdLine(Line);
+ if(Scan(ProgName, pgcmds, Line)) CmdError(*argv);
+ --argc;
+ ++argv;
+ }else{ //skip tokens not starting with '-'
+ --argc;
+ ++argv;
+ }
+ }
+ *n = argc;
+ *a = argv;
+
+#if defined(MSDOS)||defined(_WIN32)
+ if(dot) *dot = '.';
+#endif
+ free(Line);
+ return 0;
+}
+
+int
+GetDotParams(char *ParName,
+ ...)
+{
+ va_list args;
+ int j,
+ cmdN,
+ cmdSz = 0;
+ Cmd_T *cmds = 0;
+
+ va_start(args, ParName);
+ cmdN = BuildCmdList(&cmds, &cmdSz, ParName, args);
+ va_end(args);
+ for(j=0; j<CmdLinesL; j++) Scan(ProgName, cmds, CmdLines[j]);
+ for(j=0; j<cmdN; j++) FreeParam(cmds+j);
+ if(cmds) free(cmds);
+ return 0;
+}
+
+int
+GetStrParams(char **lines,
+ int n,
+ char *parName,
+ ...)
+{
+ va_list args;
+ int j,
+ cmdN,
+ cmdSz = 0;
+ Cmd_T *cmds = 0;
+
+ va_start(args, parName);
+ cmdN = BuildCmdList(&cmds, &cmdSz, parName, args);
+ va_end(args);
+ for(j=0; j<n; j++) Scan((char*)0, cmds, lines[j]);
+ for(j=0; j<cmdN; j++) FreeParam(cmds+j);
+ if(cmds) free(cmds);
+ return 0;
+}
+
+int
+PrintParams(int ValFlag,
+ FILE *fp)
+{
+ int TypeFlag=0;
+ int MsgFlag=1;
+ return PrintParams4(TypeFlag, ValFlag, MsgFlag, fp);
+}
+
+int
+FullPrintParams(int TypeFlag,
+ int ValFlag,
+ int MsgFlag,
+ FILE *fp)
+{
+ return PrintParams4(TypeFlag, ValFlag, MsgFlag, fp);
+}
+
+static int
+PrintParams4(int TypeFlag,
+ int ValFlag,
+ int MsgFlag,
+ FILE *fp)
+{
+ int i;
+
+ fflush(fp);
+ if(ValFlag) {
+ fprintf(fp, "Parameters Values:\n");
+ } else {
+ fprintf(fp, "Parameters:\n");
+ }
+ for(i=0; pgcmds[i].Name; i++) {
+ PrintParam(pgcmds+i, TypeFlag, ValFlag, fp);
+ if(MsgFlag&&pgcmds[i].Msg) {
+ char *s=pgcmds[i].Msg,
+ *p;
+ for(;(p=strchr(s, '\n')); s=++p) {
+ fprintf(fp, "%6s%*.*s\n", "", (int)(p-s), (int)(p-s), s);
+ }
+ if(s) fprintf(fp, "%6s%s", "", s);
+ }
+ fprintf(fp, "\n");
+ }
+ fprintf(fp, "\n");
+ fflush(fp);
+ return 0;
+}
+
+int
+SPrintParams(char ***a,
+ char *pfx)
+{
+ int l,
+ n;
+ Cmd_T *cmd;
+
+ if(!pfx) pfx="";
+ l = strlen(pfx);
+ for(n=0, cmd=pgcmds; cmd->Name; cmd++) n += !!cmd->ArgStr;
+ a[0] = calloc(n, sizeof(char*));
+ for(n=0, cmd=pgcmds; cmd->Name; cmd++) {
+ if(!cmd->ArgStr) continue;
+ a[0][n] = malloc(strlen(cmd->Name)+strlen(cmd->ArgStr)+l+2);
+ sprintf(a[0][n], "%s%s=%s", pfx, cmd->Name, cmd->ArgStr);
+ ++n;
+ }
+ return n;
+}
+
+static int
+BuildCmdList(Cmd_T **cmdp,
+ int *cmdSz,
+ char *ParName,
+ va_list args)
+{
+ int j,
+ c,
+ cmdN=0;
+ char *s;
+ Cmd_T *cmd,
+ *cmds;
+
+ if(!*cmdSz) {
+ if(!(cmds=*cmdp=malloc((1+(*cmdSz=BUFSIZ))*sizeof(Cmd_T)))) {
+ fprintf(stderr, "BuildCmdList(): malloc() failed\n");
+ exit(IRSTLM_CMD_ERROR_MEMORY);
+ }
+ } else {
+ for(cmds=*cmdp; cmds[cmdN].Name; ++cmdN);
+ }
+ while(ParName) {
+ if(cmdN==*cmdSz) {
+ cmds=*cmdp=realloc(cmds,
+ (1+(*cmdSz+=BUFSIZ))*sizeof(Cmd_T));
+ if(!cmds) {
+ fprintf(stderr,
+ "BuildCmdList(): realloc() failed\n");
+ exit(IRSTLM_CMD_ERROR_MEMORY);
+ }
+ }
+ for(j=0; j<cmdN&&strcmp(cmds[j].Name, ParName)<0; j++);
+ for(c=cmdN; c>j; c--) cmds[c] = cmds[c-1];
+ cmd = cmds+j;
+ cmd->Name = ParName;
+ cmd->Type = va_arg(args, int);
+ cmd->Val = va_arg(args, void*);
+ cmd->Msg = 0;
+ cmd->Flag = 0;
+ cmd->p = 0;
+
+ switch(cmd->Type&~CMDMSG) {
+ case CMDENUMTYPE: /* get the pointer to Enum_T struct */
+ case CMDFLAGTYPE:
+ cmd->p = va_arg(args, void*);
+ break;
+ case CMDSUBRANGETYPE: /* get the two limits */
+ cmd->p = (void*)calloc(2, sizeof(int));
+ ((int*)cmd->p)[0] = va_arg(args, int);
+ ((int*)cmd->p)[1] = va_arg(args, int);
+ break;
+ case CMDGTETYPE: /* lower or upper bound */
+ case CMDLTETYPE:
+ cmd->p = (void*)calloc(1, sizeof(int));
+ ((int*)cmd->p)[0] = va_arg(args, int);
+ break;
+ case CMDSTRARRAYTYPE: /* separator string */
+ cmd->p = (s=va_arg(args, char*)) ? (void*)strdup(s) : 0;
+ break;
+ case CMDDBLARRAYTYPE:
+ case CMDINTARRAYTYPE: /* separator & pointer to length */
+ cmd->p = (void*)calloc(2, sizeof(void*));
+ s = va_arg(args, char*);
+ ((char**)cmd->p)[0] = s ? strdup(s) : 0;
+ ((int**)cmd->p)[1] = va_arg(args, int*);
+ *((int**)cmd->p)[1] = 0;
+ break;
+ case CMDBOOLTYPE:
+ cmd->p = BoolEnum;
+ break;
+ //cmd->p = (Bool_T*)calloc(1, sizeof(Bool_T));
+ // cmd->p = va_arg(args, void*);
+// cmd->p = BoolEnum;
+ case CMDDOUBLETYPE: /* nothing else is needed */
+ case CMDFLOATTYPE:
+ case CMDINTTYPE:
+ case CMDSTRINGTYPE:
+ break;
+ default:
+ fprintf(stderr, "%s: %s %d %s \"%s\"\n",
+ "BuildCmdList()", "Unknown Type",
+ cmd->Type&~CMDMSG, "for parameter", cmd->Name);
+ exit(IRSTLM_CMD_ERROR_DATA);
+ }
+ if(cmd->Type&CMDMSG) {
+ cmd->Type&=~CMDMSG;
+ cmd->Msg = va_arg(args, char*);
+ }
+ cmdN++;
+ ParName = va_arg(args, char*);
+ }
+ cmds[cmdN].Name = 0;
+ return cmdN;
+}
+
+static int
+CmdError(char *opt)
+{
+ fprintf(stderr, "Invalid option \"%s\"\n", opt);
+ fprintf(stderr, "This program expects the following parameters:\n");
+ PrintParams4(TRUE, FALSE, TRUE, stderr);
+ exit(IRSTLM_CMD_ERROR_DATA);
+ return 0;
+}
+
+static int
+FreeParam(Cmd_T *cmd)
+{
+ switch(cmd->Type) {
+ case CMDBOOLTYPE2:
+ case CMDSUBRANGETYPE:
+ case CMDGTETYPE:
+ case CMDLTETYPE:
+ case CMDSTRARRAYTYPE:
+ if(cmd->p) free(cmd->p);
+ break;
+ case CMDINTARRAYTYPE:
+ case CMDDBLARRAYTYPE:
+ if(!cmd->p) break;
+ if(*(char**)cmd->p) free(*(char**)cmd->p);
+ free(cmd->p);
+ break;
+ }
+ return 0;
+}
+
+static int
+PrintParam(Cmd_T *cmd,
+ int TypeFlag,
+ int ValFlag,
+ FILE *fp)
+{
+ char ts[128];
+
+ *ts=0;
+ fprintf(fp, "%4s", "");
+ switch(cmd->Type) {
+ case CMDDOUBLETYPE:
+ fprintf(fp, "%s", cmd->Name);
+ if(TypeFlag) fprintf(fp, " [double]");
+ if(ValFlag) fprintf(fp, ": %22.15e", *(double*)cmd->Val);
+ break;
+ case CMDFLOATTYPE:
+ fprintf(fp, "%s", cmd->Name);
+ if(TypeFlag) fprintf(fp, " [float]");
+ if(ValFlag) fprintf(fp, ": %22.15e", *(float *)cmd->Val);
+ break;
+ case CMDBOOLTYPE2:
+ case CMDBOOLTYPE:
+ PrintBool(cmd, TypeFlag, ValFlag, fp);
+ break;
+ case CMDENUMTYPE:
+ PrintEnum(cmd, TypeFlag, ValFlag, fp);
+ break;
+ case CMDFLAGTYPE:
+ PrintFlag(cmd, TypeFlag, ValFlag, fp);
+ break;
+ case CMDINTTYPE:
+ if(TypeFlag) sprintf(ts, " [int]");
+ case CMDSUBRANGETYPE:
+ if(TypeFlag&&!*ts) sprintf(ts, " [int %d ... %d]",
+ ((int*)cmd->p)[0],
+ ((int*)cmd->p)[1]);
+ case CMDGTETYPE:
+ if(TypeFlag&&!*ts) sprintf(ts, " [int >= %d]",
+ ((int*)cmd->p)[0]);
+ case CMDLTETYPE:
+ if(TypeFlag&&!*ts) sprintf(ts, " [int <= %d]",
+ ((int*)cmd->p)[0]);
+ fprintf(fp, "%s", cmd->Name);
+ if(*ts) fprintf(fp, " %s", ts);
+ if(ValFlag) fprintf(fp, ": %d", *(int*)cmd->Val);
+ break;
+ case CMDSTRINGTYPE:
+ fprintf(fp, "%s", cmd->Name);
+ if(TypeFlag) fprintf(fp, " [string]");
+ if(ValFlag) {
+ if(*(char **)cmd->Val) {
+ fprintf(fp, ": \"%s\"", *(char**)cmd->Val);
+ } else {
+ fprintf(fp, ": %s", "NULL");
+ }
+ }
+ break;
+ case CMDSTRARRAYTYPE:
+ PrintStrArray(cmd, TypeFlag, ValFlag, fp);
+ break;
+ case CMDINTARRAYTYPE:
+ PrintIntArray(cmd, TypeFlag, ValFlag, fp);
+ break;
+ case CMDDBLARRAYTYPE:
+ PrintDblArray(cmd, TypeFlag, ValFlag, fp);
+ break;
+ default:
+ fprintf(stderr, "%s: %s %d %s \"%s\"\n",
+ "PrintParam",
+ "Unknown Type",
+ cmd->Type,
+ "for parameter",
+ cmd->Name);
+ exit(IRSTLM_CMD_ERROR_DATA);
+ }
+ fprintf(fp, ":");
+ // fprintf(fp, "\n");
+ fflush(fp);
+ return 0;
+}
+
+static char *
+GetLine(FILE *fp,
+ int n,
+ char *Line)
+{
+ int j,
+ l,
+ offs=0;
+
+ for(;;) {
+ if(!fgets(Line+offs, n-offs, fp)) {
+ return 0;
+ }
+ if(Line[offs]=='#') continue;
+ l = strlen(Line+offs)-1;
+ Line[offs+l] = 0;
+ for(j=offs; Line[j]&&isspace((unsigned char)Line[j]); j++,l--);
+ if(l<1) continue;
+ if(j > offs) {
+ char *s = Line+offs,
+ *q = Line+j;
+
+ while((*s++=*q++))
+ ;
+ }
+ if(Line[offs+l-1]=='\\') {
+ offs += l;
+ Line[offs-1] = ' ';
+ } else {
+ break;
+ }
+ }
+ return Line;
+}
+
+static int
+Scan(char *ProgName,
+ Cmd_T *cmds,
+ char *Line)
+{
+ char *q,
+ *p;
+ int i,
+ hl,
+ HasToMatch = FALSE,
+ c0,
+ c;
+
+ p = Line+strspn(Line, SepString);
+ if(!(hl=strcspn(p, SepString))) return 0;
+ if(ProgName&&(q=strchr(p, '/')) && q-p<hl) {
+ *q = 0;
+ if(strcmp(p, ProgName)) {
+ *q = '/';
+ return 0;
+ }
+ *q = '/';
+ HasToMatch=TRUE;
+ p = q+1;
+ }
+ if(!(hl=strcspn(p, SepString))) return 0;
+ c0 = p[hl];
+ p[hl] = 0;
+ for(i=0, c=1; cmds[i].Name&&(c=strcmp(cmds[i].Name, p))<0; i++)
+ ;
+ p[hl] = c0;
+ if(!c) return SetParam(cmds+i, p+hl+strspn(p+hl, SepString));
+ return HasToMatch && c;
+}
+
+static int
+SetParam(Cmd_T *cmd,
+ char *_s)
+{
+ char *s;
+
+ if(!*_s && cmd->Type==CMDENUMTYPE && cmd->Flag==1){
+ s=(char*) malloc(5);
+ strcpy(s,"TRUE");
+ }else{
+ s=_s;
+ }
+
+ if (!*s || (s=='\0' && cmd->Flag==0)){
+ fprintf(stderr,
+ "WARNING: No value specified for parameter \"%s\"\n",
+ cmd->Name);
+ return 0;
+ }
+
+ switch(cmd->Type) {
+ case CMDDOUBLETYPE:
+ if(sscanf(s, "%lf", (double*)cmd->Val)!=1) {
+ fprintf(stderr,
+ "Float value required for parameter \"%s\"\n",
+ cmd->Name);
+ exit(IRSTLM_CMD_ERROR_DATA);
+ }
+ break;
+ case CMDFLOATTYPE:
+ if(sscanf(s, "%f", (float*)cmd->Val)!=1) {
+ fprintf(stderr,
+ "Float value required for parameter \"%s\"\n",
+ cmd->Name);
+ exit(IRSTLM_CMD_ERROR_DATA);
+ }
+ break;
+ case CMDBOOLTYPE2:
+ case CMDBOOLTYPE:
+ SetBool(cmd, s);
+ break;
+ case CMDENUMTYPE:
+ SetEnum(cmd, s);
+ break;
+ case CMDFLAGTYPE:
+ SetFlag(cmd, s);
+ break;
+ case CMDINTTYPE:
+ if(sscanf(s, "%i", (int*)cmd->Val)!=1) {
+ fprintf(stderr,
+ "Integer value required for parameter \"%s\"\n",
+ cmd->Name);
+ exit(IRSTLM_CMD_ERROR_DATA);
+ }
+ break;
+ case CMDSTRINGTYPE:
+ *(char **)cmd->Val = (strcmp(s, "<NULL>") && strcmp(s, "NULL"))
+ ? strdup(s)
+ : 0;
+ break;
+ case CMDSTRARRAYTYPE:
+ SetStrArray(cmd, s);
+ break;
+ case CMDINTARRAYTYPE:
+ case CMDDBLARRAYTYPE:
+ SetNumArray(cmd, s);
+ break;
+ case CMDGTETYPE:
+ SetGte(cmd, s);
+ break;
+ case CMDLTETYPE:
+ SetLte(cmd, s);
+ break;
+ case CMDSUBRANGETYPE:
+ SetSubrange(cmd, s);
+ break;
+ default:
+ fprintf(stderr, "%s: %s %d %s \"%s\"\n",
+ "SetParam",
+ "Unknown Type",
+ cmd->Type,
+ "for parameter",
+ cmd->Name);
+ exit(IRSTLM_CMD_ERROR_DATA);
+ }
+ cmd->ArgStr = strdup(s);
+
+ if(!*_s && cmd->Type==CMDENUMTYPE && cmd->Flag==1){
+ free (s);
+ }
+
+ return 0;
+}
+
+static int
+SetBool(Cmd_T *cmd,
+ char *s)
+{
+ Bool_T *en;
+
+ for(en=(Bool_T*)cmd->p; en->Name; en++) {
+ if(*en->Name && !strcmp(s, en->Name)) {
+ *(char*)cmd->Val = en->Idx;
+ return 0;
+ }
+ }
+ return BoolError(cmd, s);
+}
+
+
+static int
+SetEnum(Cmd_T *cmd,
+ char *s)
+{
+ Enum_T *en;
+
+ for(en=(Enum_T*)cmd->p; en->Name; en++) {
+ if(*en->Name && !strcmp(s, en->Name)) {
+ *(int*)cmd->Val = en->Idx;
+ return 0;
+ }
+ }
+ return EnumError(cmd, s);
+}
+
+int
+EnumIdx(Enum_T *en,
+ char *s)
+{
+ if(en) for(; en->Name; en++) {
+ if(*en->Name && !strcmp(s, en->Name)) return en->Idx;
+ }
+ return -1;
+}
+
+char
+BoolIdx(Bool_T *en,
+ char *s)
+{
+ if(en) for(; en->Name; en++) {
+ if(*en->Name && !strcmp(s, en->Name)) return en->Idx;
+ }
+ return -1;
+}
+
+char *
+EnumStr(Enum_T *en,
+ int i)
+{
+ if(en) for(; en->Name; en++) if(en->Idx==i) return en->Name;
+ return 0;
+}
+
+char *
+BoolStr(Bool_T *en,
+ int i)
+{
+ if(en) for(; en->Name; en++) if(en->Idx==i) return en->Name;
+ return 0;
+}
+
+static int
+SetFlag(Cmd_T *cmd,
+ char *s)
+{
+ Enum_T *en;
+ int l;
+
+ for(; (l=strcspn(s, "+"))>0; s+=l,s+=!!*s) {
+ for(en=(Enum_T*)cmd->p;
+ en->Name&&(l!=strlen(en->Name)||strncmp(s, en->Name, l));
+ en++);
+ if(!en->Name) return EnumError(cmd, s);
+ *(int*)cmd->Val |= en->Idx;
+ }
+ return 0;
+}
+
+static int
+SetSubrange(Cmd_T *cmd,
+ char *s)
+{
+ int n;
+
+ if(sscanf(s, "%i", &n)!=1) {
+ fprintf(stderr,
+ "Integer value required for parameter \"%s\"\n",
+ cmd->Name);
+ exit(IRSTLM_CMD_ERROR_DATA);
+ }
+ if(n < *(int*)cmd->p || n > *((int*)cmd->p+1)) {
+ return SubrangeError(cmd, n);
+ }
+ *(int*)cmd->Val = n;
+ return 0;
+}
+
+static int
+SetGte(Cmd_T *cmd,
+ char *s)
+{
+ int n;
+
+ if(sscanf(s, "%i", &n)!=1) {
+ fprintf(stderr,
+ "Integer value required for parameter \"%s\"\n",
+ cmd->Name);
+ exit(IRSTLM_CMD_ERROR_DATA);
+ }
+ if(n<*(int*)cmd->p) {
+ return GteError(cmd, n);
+ }
+ *(int*)cmd->Val = n;
+ return 0;
+}
+
+static int
+SetStrArray(Cmd_T *cmd,
+ char *s)
+{
+ *(char***)cmd->Val = str2array(s, (char*)cmd->p);
+ return 0;
+}
+
+static int
+SetNumArray(Cmd_T *cmd,
+ char *s)
+{
+ *((int**)cmd->p)[1] = str2narray(cmd->Type, s,
+ *((char**)cmd->p), cmd->Val);
+ return 0;
+}
+
+static int
+SetLte(Cmd_T *cmd,
+ char *s)
+{
+ int n;
+
+ if(sscanf(s, "%i", &n)!=1) {
+ fprintf(stderr,
+ "Integer value required for parameter \"%s\"\n",
+ cmd->Name);
+ exit(IRSTLM_CMD_ERROR_DATA);
+ }
+ if(n > *(int*)cmd->p) {
+ return LteError(cmd, n);
+ }
+ *(int*)cmd->Val = n;
+ return 0;
+}
+
+static int
+EnumError(Cmd_T *cmd,
+ char *s)
+{
+ Enum_T *en;
+
+ fprintf(stderr,
+ "Invalid value \"%s\" for parameter \"%s\"\n", s, cmd->Name);
+ fprintf(stderr, "Valid values are:\n");
+ for(en=(Enum_T*)cmd->p; en->Name; en++) {
+ if(*en->Name) fprintf(stderr, " %s\n", en->Name);
+ }
+ fprintf(stderr, "\n");
+ exit(IRSTLM_CMD_ERROR_DATA);
+ return 0;
+}
+
+static int
+BoolError(Cmd_T *cmd,
+ char *s)
+{
+ Bool_T *en;
+
+ fprintf(stderr,
+ "Invalid value \"%s\" for parameter \"%s\"\n", s, cmd->Name);
+ fprintf(stderr, "Valid values are:\n");
+ for(en=(Bool_T*)cmd->p; en->Name; en++) {
+ if(*en->Name) fprintf(stderr, " %s\n", en->Name);
+ }
+ fprintf(stderr, "\n");
+ exit(IRSTLM_CMD_ERROR_DATA);
+ return 0;
+}
+
+static int
+GteError(Cmd_T *cmd,
+ int n)
+{
+ fprintf(stderr,
+ "Value %d out of range for parameter \"%s\"\n", n, cmd->Name);
+ fprintf(stderr, "Valid values must be greater than or equal to %d\n",
+ *(int*)cmd->p);
+ exit(IRSTLM_CMD_ERROR_DATA);
+ return 0;
+}
+
+static int
+LteError(Cmd_T *cmd,
+ int n)
+{
+ fprintf(stderr,
+ "Value %d out of range for parameter \"%s\"\n", n, cmd->Name);
+ fprintf(stderr, "Valid values must be less than or equal to %d\n",
+ *(int*)cmd->p);
+ exit(IRSTLM_CMD_ERROR_DATA);
+ return 0;
+}
+
+static int
+SubrangeError(Cmd_T *cmd,
+ int n)
+{
+ fprintf(stderr,
+ "Value %d out of range for parameter \"%s\"\n", n, cmd->Name);
+ fprintf(stderr, "Valid values range from %d to %d\n",
+ *(int*)cmd->p, *((int*)cmd->p+1));
+ exit(IRSTLM_CMD_ERROR_DATA);
+ return 0;
+}
+
+static int
+PrintEnum(Cmd_T *cmd,
+ int TypeFlag,
+ int ValFlag,
+ FILE *fp)
+{
+ Enum_T *en;
+
+ fprintf(fp, "%s", cmd->Name);
+ if(TypeFlag) {
+ fprintf(fp, " [enum { ");
+
+ char *sep="";
+
+ for(en=(Enum_T*)cmd->p; en->Name; en++) {
+ if(*en->Name) {
+ fprintf(fp, "%s%s", sep, en->Name);
+ sep=", ";
+ }
+ }
+ fprintf(fp, " }]");
+ }
+ if(ValFlag) {
+ for(en=(Enum_T*)cmd->p; en->Name; en++) {
+ if(*en->Name && en->Idx==*(int*)cmd->Val) {
+ fprintf(fp, ": %s", en->Name);
+ }
+ }
+ }
+ // fprintf(fp, "\n");
+ return 0;
+}
+
+static int
+PrintBool(Cmd_T *cmd,
+ int TypeFlag,
+ int ValFlag,
+ FILE *fp)
+{
+ Bool_T *en;
+
+ fprintf(fp, "%s", cmd->Name);
+ if(TypeFlag) {
+ fprintf(fp, " [enum { ");
+
+ char *sep="";
+
+ for(en=(Bool_T*)cmd->p; en->Name; en++) {
+ if(*en->Name) {
+ fprintf(fp, "%s%s", sep, en->Name);
+ sep=", ";
+ }
+ }
+ fprintf(fp, " }]");
+ }
+ if(ValFlag) {
+ for(en=(Bool_T*)cmd->p; en->Name; en++) {
+ if(*en->Name && en->Idx==*(int*)cmd->Val) {
+ fprintf(fp, ": %s", en->Name);
+ }
+ }
+ }
+ // fprintf(fp, "\n");
+ return 0;
+}
+
+static int
+PrintFlag(Cmd_T *cmd,
+ int TypeFlag,
+ int ValFlag,
+ FILE *fp)
+{
+ Enum_T *en;
+ char *sep="";
+
+ fprintf(fp, "%s", cmd->Name);
+ if(TypeFlag) {
+ fprintf(fp, ": flag { ");
+ for(en=(Enum_T*)cmd->p; en->Name; en++) {
+ if(*en->Name) {
+ fprintf(fp, "%s%s", sep, en->Name);
+ sep=", ";
+ }
+ }
+ fprintf(fp, " }");
+ }
+ if(ValFlag) {
+ fprintf(fp, ": ");
+ for(en=(Enum_T*)cmd->p; en->Name; en++) {
+ if(*en->Name && (en->Idx&*(int*)cmd->Val)==en->Idx) {
+ fprintf(fp, "%s%s", sep, en->Name);
+ sep="+";
+ }
+ }
+ }
+ fprintf(fp, "\n");
+ return 0;
+}
+
+static int
+PrintStrArray(Cmd_T *cmd,
+ int TypeFlag,
+ int ValFlag,
+ FILE *fp)
+{
+ char *indent,
+ **s = *(char***)cmd->Val;
+ int l = 4+strlen(cmd->Name);
+
+ fprintf(fp, "%s", cmd->Name);
+ if(TypeFlag) {
+ fprintf(fp, ": string array, separator \"%s\"",
+ cmd->p?(char*)cmd->p:"");
+ }
+ indent = malloc(l+2);
+ memset(indent, ' ', l+1);
+ indent[l+1] = 0;
+ if(ValFlag) {
+ fprintf(fp, ": %s", s ? (*s ? *s++ : "NULL") : "");
+ if(s) while(*s) {
+ fprintf(fp, "\n%s %s", indent, *s++);
+ }
+ }
+ free(indent);
+ fprintf(fp, "\n");
+ return 0;
+}
+
+static int
+PrintIntArray(Cmd_T *cmd,
+ int TypeFlag,
+ int ValFlag,
+ FILE *fp)
+{
+ char *indent;
+ int l = 4+strlen(cmd->Name),
+ n,
+ *i = *(int**)cmd->Val;
+
+ fprintf(fp, "%s", cmd->Name);
+ if(TypeFlag) {
+ fprintf(fp, ": int array, separator \"%s\"",
+ *(char**)cmd->p?*(char**)cmd->p:"");
+ }
+ n = *((int**)cmd->p)[1];
+ indent = malloc(l+2);
+ memset(indent, ' ', l+1);
+ indent[l+1] = 0;
+ if(ValFlag) {
+ fprintf(fp, ":");
+ if(i&&n>0) {
+ fprintf(fp, " %d", *i++);
+ while(--n) fprintf(fp, "\n%s %d", indent, *i++);
+ }
+ }
+ free(indent);
+ fprintf(fp, "\n");
+ return 0;
+}
+
+static int
+PrintDblArray(Cmd_T *cmd,
+ int TypeFlag,
+ int ValFlag,
+ FILE *fp)
+{
+ char *indent;
+ int l = 4+strlen(cmd->Name),
+ n;
+ double *x = *(double**)cmd->Val;
+
+ fprintf(fp, "%s", cmd->Name);
+ if(TypeFlag) {
+ fprintf(fp, ": double array, separator \"%s\"",
+ *(char**)cmd->p?*(char**)cmd->p:"");
+ }
+ n = *((int**)cmd->p)[1];
+ indent = malloc(l+2);
+ memset(indent, ' ', l+1);
+ indent[l+1] = 0;
+ if(ValFlag) {
+ fprintf(fp, ":");
+ if(x&&n>0) {
+ fprintf(fp, " %e", *x++);
+ while(--n) fprintf(fp, "\n%s %e", indent, *x++);
+ }
+ }
+ free(indent);
+ fprintf(fp, "\n");
+ return 0;
+}
+
+static char **
+str2array(char *s,
+ char *sep)
+{
+ char *p, **a;
+ int n = 0;
+
+ if(!sep) sep = SepString;
+ p = s += strspn(s, sep);
+ if(!*p) return 0;
+ while(*p) {
+ p += strcspn(p, sep);
+ p += strspn(p, sep);
+ ++n;
+ }
+ a = calloc(n+1, sizeof(char*));
+ p = s;
+ n = 0;
+ while(*p) {
+ int l = strcspn(p, sep);
+ a[n] = malloc(l+1);
+ memcpy(a[n], p, l);
+ a[n][l] = 0;
+ ++n;
+ p += l;
+ p += strspn(p, sep);
+ }
+ return a;
+}
+
+int
+str2narray(int type,
+ char *s,
+ char *sep,
+ void **a)
+{
+ char *p;
+ double *x;
+ int *i;
+ int n = 0;
+
+ if(!sep) sep=SepString;
+ for(p=s; *p; ) {
+ p += strcspn(p, sep);
+ p += !!*p;
+ ++n;
+ }
+ *a = 0;
+ if(!n) return 0;
+ *a = calloc(n, (type==CMDINTARRAYTYPE)?sizeof(int):sizeof(double));
+ i = (int*)*a;
+ x = (double*)*a;
+ p = s;
+ n = 0;
+ while(*p) {
+ switch(type) {
+ case CMDINTARRAYTYPE:
+ *i++ = atoi(p);
+ break;
+ case CMDDBLARRAYTYPE:
+ *x++ = atof(p);
+ break;
+ }
+ ++n;
+ p += strcspn(p, sep);
+ p += !!*p;
+ }
+ return n;
+}
+
+static int
+StoreCmdLine(char *s)
+{
+ s += strspn(s, SepString);
+ if(!*s) return 0;
+ if(CmdLinesL>=CmdLinesSz) {
+ CmdLines=CmdLinesSz
+ ? (char**)realloc(CmdLines,
+ (CmdLinesSz+=BUFSIZ)*sizeof(char**))
+ : (char**)malloc((CmdLinesSz=BUFSIZ)*sizeof(char**));
+ if(!CmdLines) {
+ fprintf(stderr, "%s\n",
+ "StoreCmdLine(): malloc() failed");
+ exit(IRSTLM_CMD_ERROR_MEMORY);
+ }
+ }
+ CmdLines[CmdLinesL++] = strdup(s);
+ return 0;
+}
+
diff --git a/src/cmd.h b/src/cmd.h
new file mode 100644
index 0000000..df0f3de
--- /dev/null
+++ b/src/cmd.h
@@ -0,0 +1,119 @@
+// $Id: cmd.h 3626 2010-10-07 11:41:05Z bertoldi $
+
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+
+#if !defined(CMD_H)
+
+#define CMD_H
+
+
+
+#define FALSE 0
+#define TRUE 1
+
+#define END_ENUM { (char*)0, 0 }
+
+
+
+#define IRSTLM_CMD_NO_ERROR 0
+#define IRSTLM_CMD_ERROR_GENERIC 1
+#define IRSTLM_CMD_ERROR_IO 2
+#define IRSTLM_CMD_ERROR_MEMORY 3
+#define IRSTLM_CMD_ERROR_DATA 4
+#define IRSTLM_CMD_ERROR_MODEL 5
+
+
+#define CMDDOUBLETYPE 1
+#define CMDENUMTYPE 2
+#define CMDINTTYPE 3
+#define CMDSTRINGTYPE 4
+#define CMDSUBRANGETYPE 5
+#define CMDGTETYPE 6
+#define CMDLTETYPE 7
+#define CMDSTRARRAYTYPE 8
+#define CMDBOOLTYPE 9
+#define CMDBOOLTYPE2 19
+#define CMDFLAGTYPE 10
+#define CMDINTARRAYTYPE 11
+#define CMDDBLARRAYTYPE 12
+#define CMDFLOATTYPE 13
+
+#define CMDMSG (1<<31)
+
+#include <stdio.h>
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+ typedef struct {
+ char *Name;
+ int Idx;
+ } Enum_T;
+
+ typedef struct {
+ char *Name;
+ char Idx;
+ } Bool_T;
+
+ typedef struct {
+ int Type;
+ int Flag;
+ void *Val;
+ void *p;
+ char *Name;
+ char *ArgStr;
+ char *Msg;
+ } Cmd_T;
+
+ int
+ DeclareParams(char *,
+ ...),
+ GetParams(int *n,
+ char ***a,
+ char *CmdFileName),
+ GetDotParams(char *,
+ ...),
+ SPrintParams(char ***a,
+ char *pfx),
+ PrintParams(int ValFlag,
+ FILE *fp),
+ FullPrintParams(int TypeFlag,
+ int ValFlag,
+ int MsgFlag,
+ FILE *fp),
+ EnumIdx(Enum_T *en,
+ char *s);
+ char BoolIdx(Bool_T *en,
+ char *s);
+ char
+ *EnumStr(Enum_T *en,
+ int i);
+ char
+ *BoolStr(Bool_T *en,
+ int i);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
+
diff --git a/src/compile-lm.cpp b/src/compile-lm.cpp
new file mode 100644
index 0000000..f3c534b
--- /dev/null
+++ b/src/compile-lm.cpp
@@ -0,0 +1,526 @@
+// $Id: compile-lm.cpp 3677 2010-10-13 09:06:51Z bertoldi $
+
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit, compile LM
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+
+
+#include <iostream>
+#include <fstream>
+#include <vector>
+#include <string>
+#include <stdlib.h>
+#include "cmd.h"
+#include "util.h"
+#include "math.h"
+#include "lmContainer.h"
+
+using namespace std;
+using namespace irstlm;
+
+/********************************/
+/********************************/
+// Print the tool banner, usage line and the full option list to stderr.
+// TypeFlag is forwarded to FullPrintParams (cmd.h) to control how much
+// per-option type information is shown.
+void print_help(int TypeFlag=0){
+  std::cerr << std::endl << "compile-lm - compiles an ARPA format LM into an IRSTLM format one" << std::endl;
+  std::cerr << std::endl << "USAGE:"  << std::endl;
+  std::cerr << "       compile-lm [options] <input-file.lm> [output-file.blm]" << std::endl;
+  std::cerr << std::endl << "DESCRIPTION:" << std::endl;
+  std::cerr << "       compile-lm reads a standard LM file in ARPA format and produces" << std::endl;
+  std::cerr << "       a compiled representation that the IRST LM toolkit can quickly" << std::endl;
+  std::cerr << "       read and process. LM file can be compressed." << std::endl;
+  std::cerr << std::endl << "OPTIONS:" << std::endl;
+
+  FullPrintParams(TypeFlag, 0, 1, stderr);
+}
+
+// Report a usage problem: print msg when given, otherwise the full help.
+// Note that when msg is non-NULL the option list is NOT printed.
+void usage(const char *msg = 0)
+{
+  if (msg) {
+    std::cerr << msg << std::endl;
+  }
+  if (!msg){
+    print_help();
+  }
+}
+
+/*
+ * Entry point of compile-lm. Parses options, loads an LM (possibly
+ * filtering it by a word list), and then performs exactly one of:
+ *  - perplexity evaluation of a text file (-eval, optionally -randcalls),
+ *  - interactive n-gram scoring from stdin (-score / -ngramscore),
+ *  - saving the LM in text (-text) or binary form (the default).
+ */
+int main(int argc, char **argv)
+{
+  char *seval=NULL;
+  char *tmpdir=NULL;
+  char *sfilter=NULL;
+
+  bool textoutput = false;
+  bool sent_PP_flag = false;
+  bool invert = false;
+  bool sscore = false;
+  bool ngramscore = false;
+  bool skeepunigrams = false;
+
+  int debug = 0;
+  bool memmap = false;
+  int requiredMaxlev = 1000;
+  int dub = 10000000;
+  int randcalls = 0;
+  float ngramcache_load_factor = 0.0;
+  float dictionary_load_factor = 0.0;
+
+  bool help=false;
+  std::vector<std::string> files;
+
+  // Register every option with the cmd.c parser; each long name has a
+  // short alias writing into the same destination variable.
+  DeclareParams((char*)
+                "text", CMDBOOLTYPE|CMDMSG, &textoutput, "output is again in text format; default is false",
+                "t", CMDBOOLTYPE|CMDMSG, &textoutput, "output is again in text format; default is false",
+                "filter", CMDSTRINGTYPE|CMDMSG, &sfilter, "filter a binary language model with a word list",
+                "f", CMDSTRINGTYPE|CMDMSG, &sfilter, "filter a binary language model with a word list",
+                "keepunigrams", CMDBOOLTYPE|CMDMSG, &skeepunigrams, "filter by keeping all unigrams in the table, default is true",
+                "ku", CMDBOOLTYPE|CMDMSG, &skeepunigrams, "filter by keeping all unigrams in the table, default is true",
+                "eval", CMDSTRINGTYPE|CMDMSG, &seval, "computes perplexity of the specified text file",
+                "e", CMDSTRINGTYPE|CMDMSG, &seval, "computes perplexity of the specified text file",
+                "randcalls", CMDINTTYPE|CMDMSG, &randcalls, "computes N random calls on the specified text file",
+                "r", CMDINTTYPE|CMDMSG, &randcalls, "computes N random calls on the specified text file",
+                "score", CMDBOOLTYPE|CMDMSG, &sscore, "computes log-prob scores of n-grams from standard input",
+                "s", CMDBOOLTYPE|CMDMSG, &sscore, "computes log-prob scores of n-grams from standard input",
+                "ngramscore", CMDBOOLTYPE|CMDMSG, &ngramscore, "computes log-prob scores of the last n-gram before an _END_NGRAM_ symbol from standard input",
+                "ns", CMDBOOLTYPE|CMDMSG, &ngramscore, "computes log-prob scores of the last n-gram before an _END_NGRAM_ symbol from standard input",
+                "debug", CMDINTTYPE|CMDMSG, &debug, "verbose output for --eval option; default is 0",
+                "d", CMDINTTYPE|CMDMSG, &debug, "verbose output for --eval option; default is 0",
+                "level", CMDINTTYPE|CMDMSG, &requiredMaxlev, "maximum level to load from the LM; if value is larger than the actual LM order, the latter is taken",
+                "l", CMDINTTYPE|CMDMSG, &requiredMaxlev, "maximum level to load from the LM; if value is larger than the actual LM order, the latter is taken",
+                "memmap", CMDBOOLTYPE|CMDMSG, &memmap, "uses memory map to read a binary LM",
+                "mm", CMDBOOLTYPE|CMDMSG, &memmap, "uses memory map to read a binary LM",
+                "dub", CMDINTTYPE|CMDMSG, &dub, "dictionary upperbound to compute OOV word penalty: default 10^7",
+                "tmpdir", CMDSTRINGTYPE|CMDMSG, &tmpdir, "directory for temporary computation, default is either the environment variable TMP if defined or \"/tmp\")",
+                "invert", CMDBOOLTYPE|CMDMSG, &invert, "builds an inverted n-gram binary table for fast access; default if false",
+                "i", CMDBOOLTYPE|CMDMSG, &invert, "builds an inverted n-gram binary table for fast access; default if false",
+                "sentence", CMDBOOLTYPE|CMDMSG, &sent_PP_flag, "computes perplexity at sentence level (identified through the end symbol)",
+                "dict_load_factor", CMDFLOATTYPE|CMDMSG, &dictionary_load_factor, "sets the load factor for ngram cache; it should be a positive real value; default is 0",
+                "ngram_load_factor", CMDFLOATTYPE|CMDMSG, &ngramcache_load_factor, "sets the load factor for ngram cache; it should be a positive real value; default is false",
+
+                "Help", CMDBOOLTYPE|CMDMSG, &help, "print this help",
+                "h", CMDBOOLTYPE|CMDMSG, &help, "print this help",
+
+                (char*)NULL
+                );
+
+  if (argc == 1){
+    usage();
+    exit_error(IRSTLM_NO_ERROR);
+  }
+
+  // Collect positional (non-dash) arguments before GetParams consumes
+  // the option arguments from argv.
+  for(int i=1; i < argc; i++) {
+    if(argv[i][0] != '-'){
+      files.push_back(argv[i]);
+    }
+  }
+
+
+  GetParams(&argc, &argv, (char*) NULL);
+
+  if (help){
+    usage();
+    exit_error(IRSTLM_NO_ERROR);
+  }
+
+  if (files.size() > 2) {
+    usage();
+    exit_error(IRSTLM_ERROR_DATA,"Warning: Too many arguments");
+  }
+
+  if (files.size() < 1) {
+    usage();
+    exit_error(IRSTLM_ERROR_DATA,"Warning: Please specify a LM file to read from");
+  }
+
+  std::string infile = files[0];
+  std::string outfile = "";
+
+  // No explicit output name: derive one from the input (basename, ".gz"
+  // stripped, ".lm"/".blm" suffix added).
+  if (files.size() == 1) {
+    outfile=infile;
+
+    //remove path information
+    std::string::size_type p = outfile.rfind('/');
+    if (p != std::string::npos && ((p+1) < outfile.size()))
+      outfile.erase(0,p+1);
+
+    //eventually strip .gz
+    // NOTE(review): compare(size()-3,...) throws std::out_of_range when the
+    // basename is shorter than 3 characters -- confirm inputs rule that out.
+    if (outfile.compare(outfile.size()-3,3,".gz")==0)
+      outfile.erase(outfile.size()-3,3);
+
+    outfile+=(textoutput?".lm":".blm");
+  } else{
+    outfile = files[1];
+  }
+
+  std::cerr << "inpfile: " << infile << std::endl;
+  std::cerr << "outfile: " << outfile << std::endl;
+  if (seval!=NULL) std::cerr << "evalfile: " << seval << std::endl;
+  if (sscore==true) std::cerr << "interactive: " << sscore << std::endl;
+  if (ngramscore==true) std::cerr << "interactive for ngrams only: " << ngramscore << std::endl;
+  if (memmap) std::cerr << "memory mapping: " << memmap << std::endl;
+  std::cerr << "loading up to the LM level " << requiredMaxlev << " (if any)" << std::endl;
+  std::cerr << "dub: " << dub<< std::endl;
+  if (tmpdir != NULL) {
+    // setenv is POSIX; TMP steers where temporary files are created.
+    if (setenv("TMP",tmpdir,1))
+      std::cerr << "temporary directory has not been set" << std::endl;
+    std::cerr << "tmpdir: " << tmpdir << std::endl;
+  }
+
+
+  //checking the language model type
+  lmContainer* lmt = lmContainer::CreateLanguageModel(infile,ngramcache_load_factor,dictionary_load_factor);
+
+  //let know that table has inverted n-grams
+  if (invert) lmt->is_inverted(invert);
+
+  lmt->setMaxLoadedLevel(requiredMaxlev);
+
+  lmt->load(infile);
+
+  //CHECK this part for sfilter to make it possible only for LMTABLE
+  // Optionally replace lmt by a version filtered through a word list;
+  // filter() succeeds only for LM types that support it.
+  if (sfilter != NULL) {
+    lmContainer* filtered_lmt = NULL;
+    std::cerr << "BEFORE sublmC (" << (void*) filtered_lmt << ") (" << (void*) &filtered_lmt << ")\n";
+
+    // the function filter performs the filtering and returns true, only for specific lm type
+    if (((lmContainer*) lmt)->filter(sfilter,filtered_lmt,skeepunigrams?"yes":"no")) {
+      std::cerr << "BFR filtered_lmt (" << (void*) filtered_lmt << ") (" << (void*) &filtered_lmt << ")\n";
+      filtered_lmt->stat();
+      delete lmt;
+      lmt=filtered_lmt;
+      std::cerr << "AFTER filtered_lmt (" << (void*) filtered_lmt << ")\n";
+      filtered_lmt->stat();
+      std::cerr << "AFTER lmt (" << (void*) lmt << ")\n";
+      lmt->stat();
+    }
+  }
+
+  if (dub) lmt->setlogOOVpenalty((int)dub);
+
+  //use caches to save time (only if PS_CACHE_ENABLE is defined through compilation flags)
+  lmt->init_caches(lmt->maxlevel());
+
+  if (seval != NULL) {
+    if (randcalls>0) {
+
+      cerr << "perform random " << randcalls << " using dictionary of test set\n";
+      dictionary *dict;
+      dict=new dictionary(seval);
+
+      //build extensive histogram
+      // NOTE(review): variable-length array (GNU extension, not ISO C++);
+      // totfreq() elements live on the stack and may be very large.
+      int histo[dict->totfreq()]; //total frequency
+      int totfreq=0;
+
+      for (int n=0; n<dict->size(); n++)
+        for (int m=0; m<dict->freq(n); m++)
+          histo[totfreq++]=n;
+
+      ngram ng(lmt->getDict());
+      srand(1234);
+      double bow;
+      int bol=0;
+
+      if (debug>1) ResetUserTime();
+
+      for (int n=0; n<randcalls; n++) {
+        //extracts a random word from dict
+        int w=histo[rand() % totfreq];
+
+        ng.pushc(lmt->getDict()->encode(dict->decode(w)));
+
+        lmt->clprob(ng,&bow,&bol); //(using caches if available)
+
+        if (debug==1) {
+          std::cout << ng.dict->decode(*ng.wordp(1)) << " [" << lmt->maxlevel()-bol << "]" << " ";
+          std::cout << std::endl;
+          std::cout.flush();
+        }
+
+        if ((n % 100000)==0) {
+          std::cerr << ".";
+          lmt->check_caches_levels();
+        }
+      }
+      std::cerr << "\n";
+      if (debug>1) PrintUserTime("Finished in");
+      if (debug>1)  lmt->stat();
+
+      delete lmt;
+      return 0;
+
+    } else {
+      // Full perplexity evaluation of the -eval file; higher debug levels
+      // print per-word scores, backoff info and (debug>4) a probability
+      // mass sanity check over the whole vocabulary.
+      if (lmt->getLanguageModelType() == _IRSTLM_LMINTERPOLATION) {
+        debug = (debug>4)?4:debug;
+        std::cerr << "Maximum debug value for this LM type: " << debug << std::endl;
+      }
+      if (lmt->getLanguageModelType() == _IRSTLM_LMMACRO) {
+        debug = (debug>4)?4:debug;
+        std::cerr << "Maximum debug value for this LM type: " << debug << std::endl;
+      }
+      if (lmt->getLanguageModelType() == _IRSTLM_LMCLASS) {
+        debug = (debug>4)?4:debug;
+        std::cerr << "Maximum debug value for this LM type: " << debug << std::endl;
+      }
+      std::cerr << "Start Eval" << std::endl;
+      std::cerr << "OOV code: " << lmt->getDict()->oovcode() << std::endl;
+      ngram ng(lmt->getDict());
+      std::cout.setf(ios::fixed);
+      std::cout.precision(2);
+
+      //			if (debug>0) std::cout.precision(8);
+      std::fstream inptxt(seval,std::ios::in);
+
+      int Nbo=0, Nw=0,Noov=0;
+      double logPr=0,PP=0,PPwp=0,Pr;
+
+      // variables for storing sentence-based Perplexity
+      int sent_Nbo=0, sent_Nw=0,sent_Noov=0;
+      double sent_logPr=0,sent_PP=0,sent_PPwp=0;
+
+
+      ng.dict->incflag(1);
+      int bos=ng.dict->encode(ng.dict->BoS());
+      int eos=ng.dict->encode(ng.dict->EoS());
+      ng.dict->incflag(0);
+
+      double bow;
+      int bol=0;
+      char *msp;
+      unsigned int statesize;
+
+      lmt->dictionary_incflag(1);
+
+      while(inptxt >> ng) {
+
+        if (ng.size>lmt->maxlevel()) ng.size=lmt->maxlevel();
+
+        // reset ngram at begin of sentence
+        if (*ng.wordp(1)==bos) {
+          ng.size=1;
+          continue;
+        }
+
+        if (ng.size>=1) {
+          Pr=lmt->clprob(ng,&bow,&bol,&msp,&statesize);
+          logPr+=Pr;
+          sent_logPr+=Pr;
+
+          if (debug==1) {
+            std::cout << ng.dict->decode(*ng.wordp(1)) << " [" << ng.size-bol << "]" << " ";
+            if (*ng.wordp(1)==eos) std::cout << std::endl;
+          }
+          else if (debug==2) {
+            std::cout << ng << " [" << ng.size-bol << "-gram]" << " " << Pr;
+            std::cout << std::endl;
+            std::cout.flush();
+          }
+          else if (debug==3) {
+            std::cout << ng << " [" << ng.size-bol << "-gram]" << " " << Pr << " bow:" << bow;
+            std::cout << std::endl;
+            std::cout.flush();
+          }
+          else if (debug==4) {
+            std::cout << ng << " [" << ng.size-bol << "-gram: recombine:" << statesize << " state:" << (void*) msp << "] [" << ng.size+1-((bol==0)?(1):bol) << "-gram: bol:" << bol << "] " << Pr << " bow:" << bow;
+            std::cout << std::endl;
+            std::cout.flush();
+          }
+          else if (debug>4) {
+            std::cout << ng << " [" << ng.size-bol << "-gram: recombine:" << statesize << " state:" << (void*) msp << "] [" << ng.size+1-((bol==0)?(1):bol) << "-gram: bol:" << bol << "] " << Pr << " bow:" << bow;
+            double totp=0.0;
+            int oldw=*ng.wordp(1);
+            double oovp=lmt->getlogOOVpenalty();
+            lmt->setlogOOVpenalty((double) 0);
+            for (int c=0; c<ng.dict->size(); c++) {
+              *ng.wordp(1)=c;
+              totp+=pow(10.0,lmt->clprob(ng)); //using caches if available
+            }
+            *ng.wordp(1)=oldw;
+
+            if ( totp < (1.0 - 1e-5) || totp > (1.0 + 1e-5))
+              std::cout << " [t=" << totp << "] POSSIBLE ERROR";
+            std::cout << std::endl;
+            std::cout.flush();
+
+            lmt->setlogOOVpenalty((double)oovp);
+          }
+
+
+          if (lmt->is_OOV(*ng.wordp(1))) {
+            Noov++;
+            sent_Noov++;
+          }
+          if (bol) {
+            Nbo++;
+            sent_Nbo++;
+          }
+          Nw++;
+          sent_Nw++;
+          // Emit and reset sentence-level statistics at each end-of-sentence.
+          if (sent_PP_flag && (*ng.wordp(1)==eos)) {
+            sent_PP=exp((-sent_logPr * log(10.0)) /sent_Nw);
+            sent_PPwp= sent_PP * (1 - 1/exp((sent_Noov *  lmt->getlogOOVpenalty()) * log(10.0) / sent_Nw));
+
+            std::cout << "%% sent_Nw=" << sent_Nw
+                      << " sent_PP=" << sent_PP
+                      << " sent_PPwp=" << sent_PPwp
+                      << " sent_Nbo=" << sent_Nbo
+                      << " sent_Noov=" << sent_Noov
+                      << " sent_OOV=" << (float)sent_Noov/sent_Nw * 100.0 << "%" << std::endl;
+            std::cout.flush();
+            //reset statistics for sentence based Perplexity
+            sent_Nw=sent_Noov=sent_Nbo=0;
+            sent_logPr=0.0;
+          }
+
+          if ((Nw %	100000)==0) {
+            std::cerr << ".";
+            lmt->check_caches_levels();
+          }
+
+        }
+      }
+
+      PP=exp((-logPr * log(10.0)) /Nw);
+
+      PPwp= PP * (1 - 1/exp((Noov *  lmt->getlogOOVpenalty()) * log(10.0) / Nw));
+
+      std::cout << "%% Nw=" << Nw
+                << " PP=" << PP
+                << " PPwp=" << PPwp
+                << " Nbo=" << Nbo
+                << " Noov=" << Noov
+                << " OOV=" << (float)Noov/Nw * 100.0 << "%";
+      if (debug) std::cout << " logPr=" <<  logPr;
+      std::cout << std::endl;
+      std::cout.flush();
+
+      if (debug>1) lmt->used_caches();
+
+      if (debug>1) lmt->stat();
+
+      delete lmt;
+      return 0;
+    };
+  }
+
+  // Interactive scoring: read n-grams from stdin and echo their log-prob.
+  if (sscore == true) {
+
+    ngram ng(lmt->getDict());
+    int bos=ng.dict->encode(ng.dict->BoS());
+
+    int bol;
+    double bow;
+    unsigned int n=0;
+
+    std::cout.setf(ios::scientific);
+    std::cout.setf(ios::fixed);
+    std::cout.precision(2);
+    std::cout << "> ";
+
+    lmt->dictionary_incflag(1);
+
+    while(std::cin >> ng) {
+
+      //std::cout << ng << std::endl;;
+      // reset ngram at begin of sentence
+      if (*ng.wordp(1)==bos) {
+        ng.size=1;
+        continue;
+      }
+
+      if (ng.size>=lmt->maxlevel()) {
+        ng.size=lmt->maxlevel();
+        ++n;
+        if ((n % 100000)==0) {
+          std::cerr << ".";
+          lmt->check_caches_levels();
+        }
+        // M_LN10 converts the log10 score to a natural-log score.
+        std::cout << ng << " p= " << lmt->clprob(ng,&bow,&bol) * M_LN10;
+        std::cout << " bo= " << bol << std::endl;
+      } else {
+        std::cout << ng << " p= NULL" << std::endl;
+      }
+      std::cout << "> ";
+    }
+    std::cout << std::endl;
+    std::cout.flush();
+    if (debug>1) lmt->used_caches();
+
+    if (debug>1) lmt->stat();
+
+    delete lmt;
+    return 0;
+  }
+
+
+  // Score only the n-gram preceding each _END_NGRAM_ sentinel on stdin.
+  if (ngramscore == true) {
+
+    const char* _END_NGRAM_="_END_NGRAM_";
+    ngram ng(lmt->getDict());
+
+    double Pr;
+    double bow;
+    int bol=0;
+    char *msp;
+    unsigned int statesize;
+
+    std::cout.setf(ios::fixed);
+    std::cout.precision(2);
+
+    ng.dict->incflag(1);
+    int endngram=ng.dict->encode(_END_NGRAM_);
+    ng.dict->incflag(0);
+
+    while(std::cin >> ng) {
+      // compute score for the last ngram when endngram symbols is found
+      // and reset ngram
+      if (*ng.wordp(1)==endngram) {
+        ng.shift();
+        if (ng.size>=lmt->maxlevel()) {
+          ng.size=lmt->maxlevel();
+        }
+
+        Pr=lmt->clprob(ng,&bow,&bol,&msp,&statesize);
+#ifndef OUTPUT_SUPPRESSED
+        std::cout << ng << " [" << ng.size-bol << "-gram: recombine:" << statesize << " state:" << (void*) msp << "] [" << ng.size+1-((bol==0)?(1):bol) << "-gram: bol:" << bol << "] " << Pr << " bow:" << bow;
+        std::cout << std::endl;
+        std::cout.flush();
+#endif
+        ng.size=0;
+      }
+    }
+
+    if (debug>1) lmt->used_caches();
+
+    if (debug>1) lmt->stat();
+
+    delete lmt;
+    return 0;
+  }
+
+  // Default action: write the LM back out (memmap-loaded LMs cannot be
+  // re-saved in binary form here).
+  if (textoutput == true) {
+    std::cerr << "Saving in txt format to " << outfile << std::endl;
+    lmt->savetxt(outfile.c_str());
+  } else if (!memmap) {
+    std::cerr << "Saving in bin format to " << outfile << std::endl;
+    lmt->savebin(outfile.c_str());
+  } else {
+    std::cerr << "Impossible to save to " << outfile << std::endl;
+  }
+  delete lmt;
+  return 0;
+}
+
diff --git a/src/cplsa.cpp b/src/cplsa.cpp
new file mode 100755
index 0000000..af9460c
--- /dev/null
+++ b/src/cplsa.cpp
@@ -0,0 +1,641 @@
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit, compile LM
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+
+#include <sys/mman.h>
+#include <stdio.h>
+#include <cmath>
+#include <string>
+#include <sstream>
+#include <pthread.h>
+#include "thpool.h"
+#include "mfstream.h"
+#include "mempool.h"
+#include "htable.h"
+#include "n_gram.h"
+#include "util.h"
+#include "dictionary.h"
+#include "ngramtable.h"
+#include "doc.h"
+#include "cplsa.h"
+
+using namespace std;
+
+namespace irstlm {
+
+// Construct a PLSA model. d: dictionary (may be NULL; then loadW() will
+// create one from a model file). top: number of topics. wd: directory for
+// temporary files (memory-mapped H). th: thread-pool size. mm: keep H in
+// a memory-mapped file instead of the heap.
+plsa::plsa(dictionary* d,int top,char* wd,int th,bool mm){
+
+    dict=d;
+
+    topics=top;
+
+    tmpdir=wd;
+
+    memorymap=mm;
+
+    threads=th;
+
+    MY_ASSERT (topics>0);
+
+    //actual model structure
+    W=NULL;
+
+    //training support structure
+    T=NULL;
+
+    //allocate/free at training time// this is the huge table
+    H=NULL;
+
+
+    srandom(100); //consistent generation of random noise
+
+    bucket=BUCKET;
+
+    maxiter=0;
+}
+
+// Destructor: release W, H and T. The original called free(T), but T is
+// allocated with new double*[...] in initT() and its rows with new[];
+// free() on a new[] pointer is undefined behavior and also leaked every
+// row. freeT() performs the matching delete[] cleanup (and is a no-op
+// when T is NULL), mirroring how freeW()/freeH() are used here.
+plsa::~plsa() {
+    freeW();
+    freeH();
+    freeT();
+}
+
+int plsa::initW(char* modelfile,float noise,int spectopic){
+
+ //W needs a dictionary, either from an existing model or
+ //from the training data
+
+ assert(W==NULL);
+
+ if (dict==NULL) loadW(modelfile);
+ else{
+ cerr << "Allocating W table\n";
+ W=new float* [dict->size()];
+ for (int i=0; i<dict->size(); i++){
+ W[i]=new float [topics](); //initialized to zero since C++11
+ memset(W[i],0,sizeof(float)*topics);
+ }
+ cerr << "Initializing W table\n";
+ if (spectopic) {
+ //special topic 0: first st most frequent
+ //assume dictionary is sorted by frequency!!!
+ float TotW=0;
+ for (int i=0; i<spectopic; i++)
+ TotW+=W[i][0]=dict->freq(i);
+ for (int i=0; i<spectopic; i++)
+ W[i][0]/=TotW;
+ }
+
+ for (int t=(spectopic?1:0); t<topics; t++) {
+ float TotW=0;
+ for (int i=spectopic; i< dict->size(); i++)
+ TotW+=W[i][t]=1 + noise * MY_RAND;
+ for (int i=spectopic; i< dict->size(); i++)
+ W[i][t]/=TotW;
+ }
+ }
+ return 1;
+}
+
+// Release the W matrix (rows then row-pointer array); safe to call when
+// W is already NULL. Always returns 1.
+int plsa::freeW(){
+    if (W!=NULL){
+        cerr << "Releasing memory of W table\n";
+        for (int i=0; i<dict->size(); i++) delete [] W[i];
+        delete [] W;
+        W=NULL;
+    }
+    return 1;
+}
+
+
+
+// Allocate and uniformly initialize the document-topic matrix H
+// (numdoc x topics). With memorymap the table lives in a temporary file
+// under tmpdir; otherwise it is heap-allocated. Always returns 1.
+int plsa::initH(){
+
+    assert(trset->numdoc()); //need a date set
+    long long len=(unsigned long long)trset->numdoc() * topics;
+
+    FILE *fd;
+    if (H == NULL){
+        if (memorymap){
+            cerr << "Creating memory mapped H table\n";
+            //generate a name for the memory map file
+            // NOTE(review): the leading '/' yields "//tmp/..." style paths
+            // (harmless on POSIX); Hfname is a fixed 100-byte buffer --
+            // confirm tmpdir cannot overflow it.
+            sprintf(Hfname,"/%s/Hfname%d",tmpdir,(int)getpid());
+            if ((fd=fopen(Hfname,"w+"))==0){
+                perror("Could not create file");
+                exit_error(IRSTLM_ERROR_IO, "plsa::initH fopen error");
+            }
+            //H is aligned at integer
+            // NOTE(review): ftruncate return value is ignored; MAP_PRIVATE
+            // means updates are never written back to the file (the file is
+            // only a backing store, removed in freeH()).
+            ftruncate(fileno(fd),len * sizeof(float));
+            H = (float *)mmap( 0, len * sizeof(float) , PROT_READ|PROT_WRITE, MAP_PRIVATE,fileno(fd),0);
+            fclose(fd);
+            if (H == MAP_FAILED){
+                perror("Mmap error");
+                exit_error(IRSTLM_ERROR_IO, "plsa::initH MMAP error");
+            }
+        }
+        else{
+            cerr << "Allocating " << len << " entries for H table\n";
+            fprintf(stderr,"%llu\n",len);
+            if ((H=new float[len])==NULL){
+                perror("memory allocation error");
+                exit_error(IRSTLM_ERROR_IO, "plsa::cannot allocate memory for H");
+            }
+        }
+    }
+    cerr << "Initializing H table " << "\n";
+    float value=1/(float)topics;
+    for (long long d=0; d< trset->numdoc(); d++)
+        for (int t=0; t<topics; t++)
+            H[d*topics+t]=value;
+    cerr << "done\n";
+    return 1;
+}
+
+// Release the H matrix: munmap + remove the backing file when memory
+// mapped, delete[] otherwise. Safe when H is NULL. Always returns 1.
+int plsa::freeH(){
+    if (H!=NULL){
+        cerr << "Releasing memory for H table\n";
+        if (memorymap){
+            munmap((void *)H,trset->numdoc()*topics*sizeof(float));
+            remove(Hfname);
+        }else
+            delete [] H;
+
+        H=NULL;
+
+    }
+    return 1;
+}
+
+
+// Allocate (first call only) and zero the expected-count matrix T
+// (|V| x topics, double precision). Always returns 1.
+int plsa::initT(){ //keep double for counts collected over the whole training data
+    if (T==NULL){
+        T=new double* [dict->size()];
+        for (int i=0; i<dict->size(); i++)
+            T[i]=new double [topics];
+    }
+    for (int i=0; i<dict->size(); i++)
+        memset((void *)T[i],0,topics * sizeof(double));
+
+    return 1;
+}
+
+// Release the T matrix (rows then row-pointer array); safe when T is
+// already NULL. Always returns 1.
+int plsa::freeT(){
+    if (T!=NULL){
+        cerr << "Releasing memory for T table\n";
+        for (int i=0; i<dict->size(); i++) delete [] T[i];
+        delete [] T;
+        T=NULL;
+    }
+    return 1;
+}
+
+
+/*
+int plsa::saveWtxt2(char* fname){
+ cerr << "Writing text W table into: " << fname << "\n";
+ mfstream out(fname,ios::out);
+ out.precision(5);
+// out << topics << "\n";
+ for (int i=0; i<dict->size(); i++) {
+ out << dict->decode(i);// << " " << dict->freq(i);
+ //double totW=0;
+ //for (int t=0; t<topics; t++) totW+=W[i][t];
+ //out <<" totPr: " << totW << " :";
+ for (int t=0; t<topics; t++)
+ out << " " << W[i][t];
+ out << "\n";
+ }
+ out.close();
+ return 1;
+}
+*/
+
+// (word, score) pair used to rank words per topic in saveWtxt().
+typedef struct {
+    int word;
+    float score;
+} mypairtype;
+
+// qsort comparator: sorts mypairtype entries by descending score.
+int comparepair (const void * a, const void * b){
+    if ( (*(mypairtype *)a).score <  (*(mypairtype *)b).score ) return 1;
+    if ( (*(mypairtype *)a).score >  (*(mypairtype *)b).score ) return -1;
+    return 0;
+}
+
+// Write a human-readable summary of W to fname: for each topic, one line
+// "T<t>" followed by the tw highest-scoring words (OOV excluded by forcing
+// its score to 0). Always returns 1.
+int plsa::saveWtxt(char* fname,int tw){
+    cerr << "Writing model W into: " << fname << "\n";
+    mfstream out(fname,ios::out);
+    out.precision(5);
+
+    mypairtype *vect=new mypairtype[dict->size()];
+
+    //    out << topics << "\n";
+    for (int t=0; t<topics; t++){
+
+        for (int i=0; i<dict->size(); i++){
+            vect[i].word=i;
+            vect[i].score=W[i][t];
+        }
+        vect[dict->oovcode()].score=0;
+        qsort((void *)vect,dict->size(),sizeof(mypairtype),comparepair);
+
+        out << "T" << t;
+        for (int i=0;i<tw;i++){
+            out << " "  << dict->decode(vect[i].word);// << " " << vect[i].score << " ";
+
+        }
+        out << "\n";
+    }
+    delete [] vect;
+    out.close();
+    return 1;
+}
+
+// Serialize the model to fname: "PLSA <topics>" header, the dictionary,
+// then W row by row as raw floats. Counterpart of loadW(). Returns 1.
+int plsa::saveW(char* fname){
+    cerr << "Saving model into: " << fname << " ...";
+    mfstream out(fname,ios::out);
+    out << "PLSA " << topics << "\n";
+    dict->save(out);
+    for (int i=0; i<dict->size(); i++)
+        out.write((const char*)W[i],sizeof(float) * topics);
+    out.close();
+    cerr << "\n";
+    return 1;
+}
+
+// Load a model previously written by saveW(): header, dictionary, W rows.
+// Requires dict==NULL (the dictionary comes from the file). If topics was
+// preset it must match the file's topic count, otherwise it is adopted.
+// Returns 1.
+int plsa::loadW(char* fname){
+    assert(dict==NULL);
+    cerr << "Loading model from: " << fname << "\n";
+    mfstream inp(fname,ios::in);
+    char header[100];
+    inp.getline(header,100);
+    cerr << header ;
+    int r;
+    sscanf(header,"PLSA %d\n",&r);
+    if (topics>0 && r != topics)
+        exit_error(IRSTLM_ERROR_DATA, "incompatible number of topics");
+    else
+        topics=r;
+
+    cerr << "Loading dictionary\n";
+    dict=new dictionary(NULL,1000000);
+    dict->load(inp);
+    dict->encode(dict->OOV());
+    cerr << "Allocating W table\n";
+    W=new float* [dict->size()];
+    for (int i=0; i<dict->size(); i++)
+        W[i]=new float [topics];
+
+    cerr << "Reading W table .... ";
+    for (int i=0; i<dict->size(); i++)
+        inp.read((char *)W[i],sizeof(float) * topics);
+
+    inp.close();
+    cerr << "\n";
+    return 1;
+}
+
+// Write unigram "word features" for document d (only for d<100): the
+// mixture W*H of the document's topic distribution, rescaled by the max
+// probability to pseudo-counts in [0,1000000], in Google-ngram format to
+// "<fname>.<d+1>". Returns 1.
+int plsa::saveWordFeatures(char* fname,long long d){
+
+    //extend this to save features for all adapation documents
+    //compute distribution on doc 0
+    assert(trset !=NULL);
+
+    if (d<100){
+
+        double *WH=new double [dict->size()];
+        char *outfname=new char[strlen(fname)+10];
+
+        sprintf(outfname,"%s.%03d",fname,(int)d+1);
+        // NOTE(review): message shows fname, but output goes to outfname.
+        cerr << "Saving word features in " << fname << "\n";
+
+        for (int i=0; i<dict->size(); i++) {
+            WH[i]=0;
+            for (int t=0; t<topics; t++)
+                WH[i]+=W[i][t]*H[(d % bucket) * topics + t];
+        }
+
+        double maxp=WH[0];
+        for (int i=1; i<dict->size(); i++)
+            if (WH[i]>maxp) maxp=WH[i];
+
+        cerr << "Get max prob" << maxp << "\n";
+
+        //save unigrams in google ngram format
+        mfstream out(outfname,ios::out);
+        for (int i=0; i<dict->size(); i++){
+            int freq=(int)floor((WH[i]/maxp) * 1000000);
+            if (freq)
+                out << dict->decode(i) <<" \t" << freq<<"\n";
+
+        }
+        out.close();
+
+        delete [] outfname;
+        delete [] WH;
+
+    }
+    return 1;
+}
+
+///*****
+// Shared state for the multi-threaded EM steps below.
+pthread_mutex_t cplsa_mut1;            // guards T and cplsa_LL updates
+pthread_mutex_t cplsa_mut2;            // declared but unused here
+double cplsa_LL=0; //Log likelihood accumulated across all documents
+const float topicthreshold=0.00001;    // prune topics with H below this
+const float deltathreshold=0.0001;     // inference convergence threshold
+
+
+// One E-step task for document d (argv carries d cast to a pointer).
+// Prunes near-zero topics, accumulates expected word-topic counts into
+// the global T (mutex-protected, together with the log-likelihood), and
+// re-estimates this document's row of H in place. Aborts via exit_error
+// if the updated H row does not sum to ~1.
+void plsa::expected_counts(void *argv){
+
+    long long d;
+    d=(long long) argv;  // task payload: the document index
+    int frac=(d * 1000)/trset->numdoc();
+
+    if (!(frac % 10)) fprintf(stderr,"%2d\b\b",frac/10);
+    //fprintf(stderr,"Thread: %lu Document: %d (out of %d)\n",(long)pthread_self(),d,trset->numdoc());
+
+    int r=topics;
+
+
+    int m=trset->doclen(d); //actual length of document
+    int N=m ; // doc length is the same of
+    double totH=0;
+
+    // prune topics whose posterior is negligible for this document
+    for (int t=0; t<r; t++) if (H[d * r + t] < topicthreshold) H[d * r + t]=0;
+
+
+    //precompute WHij i=0,...,m-1; j fixed
+    float *WH=new float [m]; //initialized to zero
+    memset(WH,0,sizeof(float)*m);
+    for (int t=0; t< r ; t++)
+        if (H[d * r + t]>0)
+            for (int i=0; i<m; i++) //count each word indipendently!!!!
+                WH[i]+=(W[trset->docword(d,i)][t] * H[d * r + t]);
+
+
+    //UPDATE LOCAL Tia (for each word and topic)
+    //seems more efficient perform local computation on complex structures
+    //and perform exclusive computations on simpler structures.
+    float *lT=new float[m * r];
+    memset(lT,0,sizeof(float)*m*r);
+    for (int t=0; t<r; t++)
+        if (H[d * r + t]>0)
+            for (int i=0; i<m; i++)
+                lT[i * r + t]=(W[trset->docword(d,i)][t] * H[d * r + t]/WH[i]);
+
+    //UPDATE GLOBAL T and cplsa_LL (critical section)
+    pthread_mutex_lock(&cplsa_mut1);
+    for (int i=0; i<m; i++){
+        for (int t=0; t<r; t++)
+            T[trset->docword(d,i)][t]+=(double)lT[i * r + t];
+        cplsa_LL+= log( WH[i] );
+    }
+    pthread_mutex_unlock(&cplsa_mut1);
+
+
+    //UPDATE Haj (topic a and document j)
+    totH=0;
+    for (int t=0; t<r; t++){
+        float tmpHaj=0;
+        if (H[d * r + t]>0){
+            for (int i=0; i < m; i++)
+                tmpHaj+=(W[trset->docword(d,i)][t] * H[d * r + t]/WH[i]);
+            H[d * r + t]=tmpHaj/N;
+            totH+=H[d * r + t];
+        }
+    }
+
+    // sanity check: the topic posterior of a document must sum to 1
+    if(totH>UPPER_SINGLE_PRECISION_OF_1 || totH<LOWER_SINGLE_PRECISION_OF_1){
+        std::stringstream ss_msg;
+        ss_msg << "Total H is wrong; totH=" << totH << " ( doc= " << d << ")\n";
+        exit_error(IRSTLM_ERROR_MODEL, ss_msg.str());
+    }
+
+    delete [] WH;
+    delete [] lT;
+
+};
+
+
+
+// Run EM training for maxiter iterations: initialize W (from modelfile if
+// dict is NULL, else with noiseW / spectopic), load the training set, then
+// per iteration farm one expected_counts task per document to the thread
+// pool, renormalize T into W (M-step) and checkpoint the model to
+// modelfile. Returns 1.
+int plsa::train(char *trainfile, char *modelfile, int maxiter,float noiseW,int spectopic){
+
+
+    //check if to either use the dict of the modelfile
+    //or create a new one from the data
+    //load training data!
+
+
+    //Initialize W matrix and load training data
+    //notice: if dict is empy, then upload from model
+    initW(modelfile,noiseW,spectopic);
+
+    //Load training data
+    trset=new doc(dict,trainfile);
+
+    //allocate H table
+    initH();
+
+    int iter=0;
+    int r=topics;
+
+    cerr << "Starting training \n";
+    threadpool thpool=thpool_init(threads);
+    task *t=new task[trset->numdoc()];
+
+    pthread_mutex_init(&cplsa_mut1, NULL);
+    //pthread_mutex_init(&cplsa_mut2, NULL);
+
+    while (iter < maxiter){
+        cplsa_LL=0;
+
+        cerr << "Iteration: " << ++iter << " ";
+
+        //initialize T table
+        initT();
+
+        for (long long d=0;d<trset->numdoc();d++){
+            //prepare and assign tasks to threads
+            t[d].ctx=this; t[d].argv=(void *)d;
+            thpool_add_work(thpool, &plsa::expected_counts_helper, (void *)&t[d]);
+
+        }
+        //join all threads
+        thpool_wait(thpool);
+
+        //Recombination and normalization of expected counts (M-step)
+        // NOTE(review): loop variable t shadows the task array t above.
+        for (int t=0; t<r; t++) {
+            double Tsum=0;
+            for (int i=0; i<dict->size(); i++) Tsum+=T[i][t];
+            for (int i=0; i<dict->size(); i++) W[i][t]=(float)(T[i][t]/Tsum);
+        }
+
+
+        cerr << " LL: " << cplsa_LL << "\n";
+        if (trset->numdoc()> 10) system("date");
+
+        // checkpoint the model after every iteration
+        saveW(modelfile);
+
+    }
+
+    //destroy thread pool
+    thpool_destroy(thpool);
+
+
+    freeH(); freeT(); freeW();
+
+    delete trset;
+    delete [] t;
+
+    return 1;
+}
+
+
+// Inference task for one document d (argv carries d cast to a pointer):
+// with W fixed, iterate the H update for this document (stored in bucket
+// slot d % bucket) until maxiter or until the largest per-topic change
+// drops below deltathreshold. Aborts if H stops summing to ~1.
+void plsa::single_inference(void *argv){
+    long long d;
+    d=(long long) argv;
+
+    int frac=(d * 1000)/trset->numdoc();
+    if (!(frac % 10)) fprintf(stderr,"%2d\b\b",frac/10);
+
+    //fprintf(stderr,"Thread: %lu Document: %d (out of %d)\n",(long)pthread_self(),d,trset->numdoc());
+
+    float *WH=new float [dict->size()];
+    bool *Hflags=new bool[topics];   // true while a topic is still active
+
+    int M=trset->doclen(d); //vocabulary size of current documents with repetitions
+
+    int N=M; //document length
+
+    //initialize H: we estimate one H for each document
+    for (int t=0; t<topics; t++) {H[(d % bucket) * topics + t]=1/(float)topics;Hflags[t]=true;}
+
+    int iter=0;
+
+    float delta=0;
+    float maxdelta=1;
+
+    while (iter < maxiter && maxdelta > deltathreshold){
+
+        maxdelta=0;
+        iter++;
+
+        //precompute denominator WH; permanently deactivate tiny topics
+        for (int t=0; t<topics; t++)
+            if (Hflags[t] && H[(d % bucket) * topics + t] < topicthreshold){ Hflags[t]=false; H[(d % bucket) * topics + t]=0;}
+
+        for (int i=0; i < M ; i++) {
+            WH[trset->docword(d,i)]=0; //initialized
+            for (int t=0; t<topics; t++){
+                if (Hflags[t])
+                    WH[trset->docword(d,i)]+=W[trset->docword(d,i)][t] * H[(d % bucket) * topics + t];
+            }
+
+        }
+
+
+
+        //UPDATE H
+        float totH=0;
+        for (int t=0; t<topics; t++) {
+            if (Hflags[t]){
+                float tmpH=0;
+                for (int i=0; i< M ; i++)
+                    tmpH+=(W[trset->docword(d,i)][t] * H[(d % bucket) * topics + t]/WH[trset->docword(d,i)]);
+                // NOTE(review): abs on a float -- with <cmath> and
+                // "using namespace std" this picks std::abs(float); if only
+                // C abs(int) were visible it would truncate. Confirm.
+                delta=abs(H[(d % bucket) * topics + t]-tmpH/N);
+                if (delta > maxdelta) maxdelta=delta;
+                H[(d % bucket) * topics + t]=tmpH/N;
+                totH+=H[(d % bucket) * topics + t]; //to check that sum is 1
+            }
+        }
+
+        if(totH>UPPER_SINGLE_PRECISION_OF_1 || totH<LOWER_SINGLE_PRECISION_OF_1) {
+            cerr << "totH " << totH << "\n";
+            std::stringstream ss_msg;
+            ss_msg << "Total H is wrong; totH=" << totH << "\n";
+            exit_error(IRSTLM_ERROR_MODEL, ss_msg.str());
+        }
+
+    }
+    //cerr << "Stopped at iteration " << iter << "\n";
+
+    delete [] WH; delete [] Hflags;
+
+
+}
+
+
+
+// Batch inference: load a model (W fixed), then process testfile in
+// buckets of BUCKET documents, running single_inference tasks in parallel.
+// After each bucket the per-document topic mixtures are appended to
+// topicfeatfile and/or word features written via saveWordFeatures().
+// Returns 1.
+int plsa::inference(char *testfile, char* modelfile, int maxit, char* topicfeatfile,char* wordfeatfile){
+
+    if (topicfeatfile) {mfstream out(topicfeatfile,ios::out);} //empty the file
+    //load existing model
+    initW(modelfile,0,0);
+
+    //load existing model
+    trset=new doc(dict,testfile);
+
+    bucket=BUCKET;      //initialize the bucket size
+    maxiter=maxit;      //set maximum number of iterations
+
+    //use one vector H for all document
+    H=new float[topics*bucket]; memset(H,0,sizeof(float)*(long long)topics*bucket);
+
+    threadpool thpool=thpool_init(threads);
+    task *t=new task[bucket];
+
+
+    cerr << "Start inference: ";
+
+    for (long long d=0;d<trset->numdoc();d++){
+
+        t[d % bucket].ctx=this; t[d % bucket].argv=(void *)d;
+        thpool_add_work(thpool, &plsa::single_inference_helper, (void *)&t[d % bucket]);
+
+        // flush when the bucket is full or at the last document
+        if (((d % bucket) == (bucket-1)) || (d==(trset->numdoc()-1)) ){
+            //join all threads
+            thpool_wait(thpool);
+
+            // shrink the bucket for the final, partially-filled flush
+            if ((d % bucket) != (bucket-1))
+                bucket=trset->numdoc() % bucket; //last bucket at end of file
+
+            if (topicfeatfile){
+                mfstream out(topicfeatfile,ios::out | ios::app);
+
+                for (int b=0;b<bucket;b++){ //include the case of
+                    out << H[b * topics];
+                    for (int t=1; t<topics; t++) out << " " << H[b * topics + t];
+                    out << "\n";
+                }
+            }
+            if (wordfeatfile){
+                //cout << "from: " << d-bucket << " to: " << d-1 << "\n";
+                for (int b=0;b<bucket;b++) saveWordFeatures(wordfeatfile,d-bucket+b);
+            }
+
+        }
+
+
+    }
+
+    delete [] H; delete [] t;
+    delete trset;
+    return 1;
+}
+} //namespace irstlm
diff --git a/src/cplsa.h b/src/cplsa.h
new file mode 100755
index 0000000..a57d8c3
--- /dev/null
+++ b/src/cplsa.h
@@ -0,0 +1,85 @@
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit, compile LM
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+
+#ifndef MF_CPLSA_H
+#define MF_CPLSA_H
+
+namespace irstlm {
+
+// Probabilistic Latent Semantic Analysis model: factorizes the word/document
+// co-occurrence statistics into a word-topic matrix W and a document-topic
+// matrix H, trained/applied with EM. Supports multi-threaded batch
+// processing and (optionally) memory-mapping of H.
+class plsa {
+ dictionary* dict; //dictionary
+ int topics; //number of topics
+ doc* trset; //training/inference set
+
+ double **T; //support matrix (keep double precision here!)
+
+ float **W; //word - topic matrix
+ float *H; //document-topic: matrix (memory mapped)
+
+ char Hfname[100]; //temporary and unique filename for H
+ char *tmpdir; //working directory for temporary files
+ bool memorymap; //use or not memory mapping
+
+ //private info shared among threads
+ int threads; //number of worker threads
+ int bucket; //parallel inference
+ int maxiter; //maximum iterations for inference
+ //work item passed to a thread: object pointer + per-task argument
+ struct task {
+ void *ctx;
+ void *argv;
+ };
+
+public:
+
+
+ plsa(dictionary* dict,int topics,char* workdir,int threads,bool mm);
+ ~plsa();
+
+ //persistence of the word-topic matrix W (binary and textual form)
+ int saveW(char* fname);
+ int saveWtxt(char* fname,int tw=10);
+ int loadW(char* fname);
+
+ //allocation/initialization and release of the model matrices
+ int initW(char* modelfile, float noise,int spectopic); int freeW();
+ int initH();int freeH();
+ int initT();int freeT();
+
+ //E-step: accumulate expected counts for one document (thread body)
+ void expected_counts(void *argv);
+
+ //trampoline: unpack a task and dispatch to the member function
+ static void *expected_counts_helper(void *argv){
+ task t=*(task *)argv;
+ ((plsa *)t.ctx)->expected_counts(t.argv);return NULL;
+ };
+
+ //trampoline: unpack a task and dispatch to the member function
+ static void *single_inference_helper(void *argv){
+ task t=*(task *)argv;
+ ((plsa *)t.ctx)->single_inference(t.argv);return NULL;
+ };
+
+ int train(char *trainfile,char* modelfile, int maxiter, float noiseW,int spectopic=0);
+ int inference(char *trainfile, char* modelfile, int maxiter, char* topicfeatfile,char* wordfeatfile);
+
+ //estimate the topic distribution of a single document (thread body)
+ void single_inference(void *argv);
+
+ int saveWordFeatures(char* fname, long long d);
+
+};
+
+} //namespace irstlm
+#endif
diff --git a/src/crc.cpp b/src/crc.cpp
new file mode 100644
index 0000000..eb746a0
--- /dev/null
+++ b/src/crc.cpp
@@ -0,0 +1,75 @@
+/*
+ * Copyright 2001-2010 Georges Menie (www.menie.org)
+ * All rights reserved.
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution.
+ * * Neither the name of the University of California, Berkeley nor the
+ * names of its contributors may be used to endorse or promote products
+ * derived from this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND ANY
+ * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE REGENTS AND CONTRIBUTORS BE LIABLE FOR ANY
+ * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#include "crc.h"
+
+/* CRC16 implementation according to CCITT standards */
+
+static const unsigned short crc16tab[256]= {
+ 0x0000,0x1021,0x2042,0x3063,0x4084,0x50a5,0x60c6,0x70e7,
+ 0x8108,0x9129,0xa14a,0xb16b,0xc18c,0xd1ad,0xe1ce,0xf1ef,
+ 0x1231,0x0210,0x3273,0x2252,0x52b5,0x4294,0x72f7,0x62d6,
+ 0x9339,0x8318,0xb37b,0xa35a,0xd3bd,0xc39c,0xf3ff,0xe3de,
+ 0x2462,0x3443,0x0420,0x1401,0x64e6,0x74c7,0x44a4,0x5485,
+ 0xa56a,0xb54b,0x8528,0x9509,0xe5ee,0xf5cf,0xc5ac,0xd58d,
+ 0x3653,0x2672,0x1611,0x0630,0x76d7,0x66f6,0x5695,0x46b4,
+ 0xb75b,0xa77a,0x9719,0x8738,0xf7df,0xe7fe,0xd79d,0xc7bc,
+ 0x48c4,0x58e5,0x6886,0x78a7,0x0840,0x1861,0x2802,0x3823,
+ 0xc9cc,0xd9ed,0xe98e,0xf9af,0x8948,0x9969,0xa90a,0xb92b,
+ 0x5af5,0x4ad4,0x7ab7,0x6a96,0x1a71,0x0a50,0x3a33,0x2a12,
+ 0xdbfd,0xcbdc,0xfbbf,0xeb9e,0x9b79,0x8b58,0xbb3b,0xab1a,
+ 0x6ca6,0x7c87,0x4ce4,0x5cc5,0x2c22,0x3c03,0x0c60,0x1c41,
+ 0xedae,0xfd8f,0xcdec,0xddcd,0xad2a,0xbd0b,0x8d68,0x9d49,
+ 0x7e97,0x6eb6,0x5ed5,0x4ef4,0x3e13,0x2e32,0x1e51,0x0e70,
+ 0xff9f,0xefbe,0xdfdd,0xcffc,0xbf1b,0xaf3a,0x9f59,0x8f78,
+ 0x9188,0x81a9,0xb1ca,0xa1eb,0xd10c,0xc12d,0xf14e,0xe16f,
+ 0x1080,0x00a1,0x30c2,0x20e3,0x5004,0x4025,0x7046,0x6067,
+ 0x83b9,0x9398,0xa3fb,0xb3da,0xc33d,0xd31c,0xe37f,0xf35e,
+ 0x02b1,0x1290,0x22f3,0x32d2,0x4235,0x5214,0x6277,0x7256,
+ 0xb5ea,0xa5cb,0x95a8,0x8589,0xf56e,0xe54f,0xd52c,0xc50d,
+ 0x34e2,0x24c3,0x14a0,0x0481,0x7466,0x6447,0x5424,0x4405,
+ 0xa7db,0xb7fa,0x8799,0x97b8,0xe75f,0xf77e,0xc71d,0xd73c,
+ 0x26d3,0x36f2,0x0691,0x16b0,0x6657,0x7676,0x4615,0x5634,
+ 0xd94c,0xc96d,0xf90e,0xe92f,0x99c8,0x89e9,0xb98a,0xa9ab,
+ 0x5844,0x4865,0x7806,0x6827,0x18c0,0x08e1,0x3882,0x28a3,
+ 0xcb7d,0xdb5c,0xeb3f,0xfb1e,0x8bf9,0x9bd8,0xabbb,0xbb9a,
+ 0x4a75,0x5a54,0x6a37,0x7a16,0x0af1,0x1ad0,0x2ab3,0x3a92,
+ 0xfd2e,0xed0f,0xdd6c,0xcd4d,0xbdaa,0xad8b,0x9de8,0x8dc9,
+ 0x7c26,0x6c07,0x5c64,0x4c45,0x3ca2,0x2c83,0x1ce0,0x0cc1,
+ 0xef1f,0xff3e,0xcf5d,0xdf7c,0xaf9b,0xbfba,0x8fd9,0x9ff8,
+ 0x6e17,0x7e36,0x4e55,0x5e74,0x2e93,0x3eb2,0x0ed1,0x1ef0
+};
+
+// Compute the CRC16-CCITT checksum (polynomial 0x1021, initial value 0x0000)
+// of the first `len` bytes of `buf`, using the byte-wise table above.
+unsigned short crc16_ccitt(const char *buf, int len)
+{
+ int counter;
+ unsigned short crc = 0;
+ // Note: `register` (deprecated/removed in modern C++) and the
+ // const-discarding (char*) cast were dropped; the &0x00FF mask makes
+ // the result identical regardless of char signedness.
+ for( counter = 0; counter < len; counter++)
+ crc = (crc<<8) ^ crc16tab[((crc>>8) ^ (unsigned char)*buf++)&0x00FF];
+ return crc;
+}
+
diff --git a/src/crc.h b/src/crc.h
new file mode 100644
index 0000000..b7a1166
--- /dev/null
+++ b/src/crc.h
@@ -0,0 +1,33 @@
+/*
+ * Copyright 2001-2010 Georges Menie (www.menie.org)
+ * All rights reserved.
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * * Redistributions of source code must retain the above copyright
+ * notice, this list of conditions and the following disclaimer.
+ * * Redistributions in binary form must reproduce the above copyright
+ * notice, this list of conditions and the following disclaimer in the
+ * documentation and/or other materials provided with the distribution.
+ * * Neither the name of the University of California, Berkeley nor the
+ * names of its contributors may be used to endorse or promote products
+ * derived from this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE REGENTS AND CONTRIBUTORS ``AS IS'' AND ANY
+ * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+ * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE REGENTS AND CONTRIBUTORS BE LIABLE FOR ANY
+ * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+ * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+ * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
+ * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+ * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+
+#ifndef _CRC16_H_
+#define _CRC16_H_
+
+// Returns the CRC16-CCITT checksum (polynomial 0x1021, initial value 0x0000)
+// of the first `len` bytes of `buf`.
+unsigned short crc16_ccitt(const char *buf, int len);
+
+#endif /* _CRC16_H_ */
diff --git a/src/cswa.cpp b/src/cswa.cpp
new file mode 100755
index 0000000..e796578
--- /dev/null
+++ b/src/cswa.cpp
@@ -0,0 +1,214 @@
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit, compile LM
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+
+
+#include <iostream>
+#include "cmd.h"
+#include <pthread.h>
+#include "thpool.h"
+#include "util.h"
+#include "mfstream.h"
+#include "mempool.h"
+#include "htable.h"
+#include "dictionary.h"
+#include "n_gram.h"
+#include "ngramtable.h"
+#include "doc.h"
+#include "cswam.h"
+
+using namespace std;
+using namespace irstlm;
+
+// Print the cswa usage banner (training mode, alignment mode, expected data
+// format) followed by the full parameter list (FullPrintParams) to stderr.
+void print_help(int TypeFlag=0){
+ std::cerr << std::endl << "cswa - continuous space word alignment model" << std::endl;
+ std::cerr << std::endl << "USAGE:" << std::endl;
+ std::cerr << " Training mode:" << std::endl;
+ std::cerr << " cswa -sd=<src-data> -td=<trg-data> -w2v=<word2vec> -m=<model> -it=<iterations> -th=<threads> [options]" << std::endl;
+ std::cerr << " Alignment mode:" << std::endl;
+ std::cerr << " cswa -sd=<src-data> -td=<trg-data> -w2v=<word2vec> -m=<model> -al=<alignment-file> -th=<threads> [options]" << std::endl;
+ std::cerr << " Data format:" << std::endl;
+ std::cerr << " <src-data> and <trg-data> must have an header with the number of following lines. " << std::endl;
+ std::cerr << " Each text line must be surrounded by the symbols <d> and </d>. " << std::endl;
+ std::cerr << " Hint: (echo `wc -l < yourfile`; add-start-end.sh -s \"d\" < yourfile) > yourfile.doc " << std::endl;
+
+ std::cerr << std::endl;
+
+ FullPrintParams(TypeFlag, 0, 1, stderr);
+}
+
+void usage(const char *msg = 0)
+{
+ if (msg){
+ std::cerr << msg << std::endl;
+ }
+ else{
+ print_help();
+ }
+}
+
+// Entry point of the cswa tool: parse command-line options, construct the
+// cswam model, then run training (-it), alignment (-al) and/or a textual
+// model dump (-txt) as requested. Help-message defaults below were fixed
+// to match the actual initial values of the variables.
+int main(int argc, char **argv){
+
+ char *srcdatafile=NULL;
+ char *trgdatafile=NULL;
+
+ char *w2vfile=NULL;
+ char *modelfile=NULL;
+ char *modeltxtfile=NULL;
+ char *alignfile=NULL;
+
+ bool forcemodel=false;
+
+ int iterations=0; //number of EM iterations to run
+ int threads=1; //number of worker threads
+ bool help=false;
+ bool trainvar=true;
+ bool normvectors=false;
+ bool usenullword=true;
+ double fixnullprob=0;
+ bool verbosity=false;
+ double minvar=0.2;
+ bool distmean=true;
+ bool distvar=true;
+ bool distbeta=false;
+ int model1iter=7;
+ int distwin=8;
+
+ DeclareParams((char*)
+
+
+ "SrcData", CMDSTRINGTYPE|CMDMSG, &srcdatafile, "<fname> : source text collection ",
+ "sd", CMDSTRINGTYPE|CMDMSG, &srcdatafile, "<fname> : source text collection ",
+
+ "TrgData", CMDSTRINGTYPE|CMDMSG, &trgdatafile, "<fname> : target text collection ",
+ "td", CMDSTRINGTYPE|CMDMSG, &trgdatafile, "<fname> : target text collection ",
+
+ "Word2Vec", CMDSTRINGTYPE|CMDMSG, &w2vfile, "<fname> : word2vec file ",
+ "w2v", CMDSTRINGTYPE|CMDMSG, &w2vfile, "<fname> : word2vec file ",
+
+ "Model", CMDSTRINGTYPE|CMDMSG, &modelfile, "<fname> : model file",
+ "m", CMDSTRINGTYPE|CMDMSG, &modelfile, "<fname> : model file",
+
+ "Iterations", CMDINTTYPE|CMDMSG, &iterations, "<count> : training iterations",
+ "it", CMDINTTYPE|CMDMSG, &iterations, "<count> : training iterations",
+
+ "Alignments", CMDSTRINGTYPE|CMDMSG, &alignfile, "<fname> : output alignment file",
+ "al", CMDSTRINGTYPE|CMDMSG, &alignfile, "<fname> : output alignment file",
+
+ "UseNullWord", CMDBOOLTYPE|CMDMSG, &usenullword, "<bool>: use null word (default true)",
+ "unw", CMDBOOLTYPE|CMDMSG, &usenullword, "<bool>: use null word (default true)",
+
+ "Threads", CMDINTTYPE|CMDMSG, &threads, "<count>: number of threads (default 1)",
+ "th", CMDINTTYPE|CMDMSG, &threads, "<count>: number of threads (default 1)",
+
+ "ForceModel", CMDBOOLTYPE|CMDMSG, &forcemodel, "<bool>: force to use existing model for training",
+ "fm", CMDBOOLTYPE|CMDMSG, &forcemodel, "<bool>: force to use existing model for training",
+
+ "TrainVariances", CMDBOOLTYPE|CMDMSG, &trainvar, "<bool>: train variances (default true)",
+ "tv", CMDBOOLTYPE|CMDMSG, &trainvar, "<bool>: train variances (default true)",
+
+ "FixNullProb", CMDDOUBLETYPE|CMDMSG, &fixnullprob, "<value>: fix null probability (default estimate)",
+ "fnp", CMDDOUBLETYPE|CMDMSG, &fixnullprob, "<value>: fix null probability (default estimate)",
+
+ "MinVariance", CMDDOUBLETYPE|CMDMSG, &minvar, "<value>: minimum variance (default 0.2)",
+ "mv", CMDDOUBLETYPE|CMDMSG, &minvar, "<value>: minimum variance (default 0.2)",
+
+ "NormalizeVectors", CMDBOOLTYPE|CMDMSG, &normvectors, "<bool>: normalize vectors (default false)",
+ "nv", CMDBOOLTYPE|CMDMSG, &normvectors, "<bool>: normalize vectors (default false)",
+
+ "DistVar", CMDBOOLTYPE|CMDMSG, &distvar, "<bool>: use distortion variance (default true)",
+ "dv", CMDBOOLTYPE|CMDMSG, &distvar, "<bool>: use distortion variance (default true)",
+
+ "DistMean", CMDBOOLTYPE|CMDMSG, &distmean, "<bool>: use distortion mean (default true)",
+ "dm", CMDBOOLTYPE|CMDMSG, &distmean, "<bool>: use distortion mean (default true)",
+
+ "DistBeta", CMDBOOLTYPE|CMDMSG, &distbeta, "<bool>: use beta distribution for distortion (default false)",
+ "db", CMDBOOLTYPE|CMDMSG, &distbeta, "<bool>: use beta distribution for distortion (default false)",
+
+ "TxtModel", CMDSTRINGTYPE|CMDMSG, &modeltxtfile, "<fname> : model in textual form",
+ "txt", CMDSTRINGTYPE|CMDMSG, &modeltxtfile, "<fname> : model in readable form",
+
+ "DistWin", CMDINTTYPE|CMDMSG, &distwin, "<count>: distortion window (default 8)",
+ "dw", CMDINTTYPE|CMDMSG, &distwin, "<count>: distortion window (default 8)",
+
+ "M1iter", CMDINTTYPE|CMDMSG, &model1iter, "<count>: number of iterations with model 1 (default 7)",
+ "m1", CMDINTTYPE|CMDMSG, &model1iter, "<count>: number of iterations with model 1 (default 7)",
+
+ "Verbosity", CMDBOOLTYPE|CMDMSG, &verbosity, "verbose output",
+ "v", CMDBOOLTYPE|CMDMSG, &verbosity, "verbose output",
+
+
+ "Help", CMDBOOLTYPE|CMDMSG, &help, "print this help",
+ "h", CMDBOOLTYPE|CMDMSG, &help, "print this help",
+
+ (char *)NULL
+ );
+
+ //no arguments at all: show help and quit
+ if (argc == 1){
+ usage();
+ exit_error(IRSTLM_NO_ERROR);
+ }
+
+ GetParams(&argc, &argv, (char*) NULL);
+
+ if (help){
+ usage();
+ exit_error(IRSTLM_NO_ERROR);
+ }
+
+
+ //the data, word2vec and model files are mandatory in both modes
+ if (!srcdatafile || !trgdatafile || !w2vfile || !modelfile ) {
+ usage();
+ exit_error(IRSTLM_ERROR_DATA,"Missing parameters");
+ }
+
+ //check if model is readable
+ bool testmodel=false;
+ FILE* f;if ((f=fopen(modelfile,"r"))!=NULL){fclose(f);testmodel=true;}
+
+ //refuse to silently update an existing model during training
+ if (iterations && testmodel && !forcemodel)
+ exit_error(IRSTLM_ERROR_DATA,"Use -ForceModel=y option to update an existing model.");
+
+ cswam *model=new cswam(srcdatafile,trgdatafile,w2vfile,
+ forcemodel,
+ usenullword,fixnullprob,
+ normvectors,
+ model1iter,
+ trainvar,minvar,
+ distwin,distbeta, distmean,distvar,
+ verbosity);
+
+ if (iterations)
+ model->train(srcdatafile,trgdatafile,modelfile,iterations,threads);
+
+ if (alignfile)
+ model->test(srcdatafile,trgdatafile,modelfile,alignfile,threads);
+
+ if (modeltxtfile){
+ model->loadModel(modelfile);
+ model->saveModelTxt(modeltxtfile);
+ }
+
+ delete model;
+
+ exit_error(IRSTLM_NO_ERROR);
+}
+
+
+
diff --git a/src/cswam.cpp b/src/cswam.cpp
new file mode 100755
index 0000000..d739989
--- /dev/null
+++ b/src/cswam.cpp
@@ -0,0 +1,1484 @@
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit, compile LM
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ *******************************************************************************/
+
+#include <sys/mman.h>
+#include <stdio.h>
+#include <cmath>
+#include <limits>
+#include <string>
+#include <sstream>
+#include <pthread.h>
+#include "thpool.h"
+#include "crc.h"
+#include "mfstream.h"
+#include "mempool.h"
+#include "htable.h"
+#include "n_gram.h"
+#include "util.h"
+#include "dictionary.h"
+#include "ngramtable.h"
+#include "doc.h"
+#include <algorithm>
+#include <vector>
+#include "cswam.h"
+
+using namespace std;
+
+namespace irstlm {
+
+// Construct a cswam model: store configuration flags, build the source and
+// target dictionaries from the two text collections, and load (or randomly
+// generate) a word2vec embedding for every source word.
+cswam::cswam(char* sdfile,char *tdfile, char* w2vfile,
+ bool forcemodel,
+ bool usenull,double fixnullprob,
+ bool normvect,
+ int model1iter,
+ bool trainvar,float minvar,
+ int distwin,bool distbeta,bool distmean,bool distvar,
+ bool verbose){
+
+ //actual model structure
+
+ TM=NULL;
+ A=NULL;
+ Den=NULL;
+ friends=NULL;
+ efcounts=NULL;
+ ecounts=NULL;
+ loc_efcounts=NULL;
+ loc_ecounts=NULL;
+
+ //setting
+ incremental_train=forcemodel;
+ normalize_vectors=normvect;
+ train_variances=trainvar;
+ use_null_word=usenull;
+ min_variance=minvar;
+ distortion_window=distwin;
+ distortion_mean=distmean;
+ distortion_var=distvar;
+ use_beta_distortion=distbeta;
+ fix_null_prob=fixnullprob;
+ DistMean=DistVar=0; //distortion mean and variance
+ DistA=DistB=0; //beta parameters
+ NullProb=0;
+ M1iter=model1iter;
+
+ //set minimum word frequency to collect friends
+ minfreq=10;
+
+
+ cout << "cswam configuration.\n";
+ cout << "Vectors: normalize [" << normalize_vectors << "] \n";
+ cout << "Gaussian Variances: train [" << train_variances << "] min [" << min_variance << "] initial [" << min_variance * SSEED << "]\n";
+ cout << "Null word: active [" << use_null_word << "] fix_null_prob [" << fix_null_prob << "]\n";
+ cout << "Distortion model: window [" << distortion_window << "] use beta [" << use_beta_distortion << "] update mean [" << distortion_mean << "] update variance [" << distortion_var << "]\n";
+
+
+ srandom(100); //ensure replicable generation of random numbers
+ bucket=BUCKET;
+ threads=1;
+ verbosity=verbose;
+
+ //create dictionaries
+ srcdict=new dictionary(NULL,100000); srcdict->generate(sdfile,true);
+ trgdict=new dictionary(NULL,100000); trgdict->generate(tdfile,true);
+
+ //make aware of oov word
+ srcdict->encode(srcdict->OOV());
+ trgdict->encode(trgdict->OOV());
+
+ trgBoD = trgdict->encode(trgdict->BoD()); //codes for begin/end sentence markers
+ trgEoD = trgdict->encode(trgdict->EoD());
+
+ srcBoD = srcdict->encode(srcdict->BoD()); //codes for begin/end sentence markers
+ srcEoD = srcdict->encode(srcdict->EoD());
+
+
+ //load word2vec dictionary
+ W2V=NULL; D=0;
+ loadword2vec(w2vfile);
+
+ //check consistency of word2vec with target vocabulary
+
+
+}
+
+// Destructor: release the translation model, the word2vec embeddings, the
+// friends lists and both dictionaries. A must already have been freed.
+// Fixes: M and S are allocated with new[], so they must be released with
+// delete[] (plain delete was undefined behavior); the second "Releasing
+// memory" message wrongly said srcdict while freeing trgdict.
+cswam::~cswam() {
+
+ assert(A==NULL);
+
+ if (TM){
+ cerr << "Releasing memory of Translation Model\n";
+ for (int e=0;e<trgdict->size();e++){
+ for (int n=0;n<TM[e].n;n++){
+ delete [] TM[e].G[n].M;delete [] TM[e].G[n].S;
+ }
+ delete [] TM[e].G; delete [] TM[e].W;
+ }
+ delete [] TM;
+ }
+ if (W2V){
+ cerr << "Releasing memory of W2W\n";
+ for (int f=0;f<srcdict->size();f++)
+ if (W2V[f]!=NULL) delete [] W2V[f];
+ delete [] W2V;
+ }
+
+ if (friends) delete [] friends;
+
+ cerr << "Releasing memory of srcdict\n";
+ delete srcdict;
+ cerr << "Releasing memory of trgdict\n";
+ delete trgdict;
+
+
+}
+
+// Deterministically generate a pseudo-random D-dimensional vector for `word`:
+// the RNG is seeded with the word's CRC16 plus iteration index `it`, so the
+// same (word, it) pair always yields the same vector, while different `it`
+// values yield different vectors for the same word.
+void cswam::randword2vec(const char* word,float* vec,int it){
+
+ //initialize random generator
+ srandom(crc16_ccitt(word,strlen(word))+it);
+
+ //generate random numbers between -1 and +1,
+ //then scale and shift according to w2v
+ for (int d=0;d<D;d++)
+ vec[d]=(float)(MY_RAND * SSEED * min_variance);
+}
+
+
+// Load a textual word2vec file (header: vocabulary size and dimension D,
+// then one word + D floats per line) into W2V, indexed by source dictionary
+// code. Vectors of words not in srcdict are skipped; source words missing
+// from word2vec get a deterministic random vector; optionally all vectors
+// are L2-normalized.
+void cswam::loadword2vec(char* fname){
+
+ cerr << "Loading word2vec file " << fname;
+ mfstream inp(fname,ios::in);
+
+ long long w2vsize;
+ inp >> w2vsize; cerr << " size= " << w2vsize;
+ inp >> D ; cout << " dim= " << D << "\n";
+
+ assert(D>0 && D<1000);
+
+ int srcoov=srcdict->oovcode();
+
+ W2V=new float* [srcdict->size()];
+ for (int f=0;f<srcdict->size();f++) W2V[f]=NULL;
+
+ //NOTE(review): fixed 100-char buffer; `inp >> word` will overflow on
+ //words of 100+ characters -- confirm an upstream length guarantee.
+ char word[100]; float dummy; int f;
+
+ for (long long i=0;i<w2vsize;i++){
+ inp >> word;
+ f=srcdict->encode(word);
+ if (f != srcoov){
+ W2V[f]=new float[D];
+ for (int d=0;d<D;d++) inp >> W2V[f][d];
+ }
+ else //skip this word
+ for (int d=0;d<D;d++) inp >> dummy;
+
+ if (!(i % 10000)) cerr<< ".";
+ }
+ cerr << "\n";
+
+
+ cerr << "looking for missing source words in w2v\n";
+ int newwords=0;
+ for ( f=0;f<srcdict->size();f++){
+ if (W2V[f]==NULL && f!=srcBoD && f!=srcEoD) {
+ if (verbosity)
+ cerr << "Missing src word in w2v: [" << f << "] " << srcdict->decode(f) << "\n";
+
+ W2V[f]=new float[D];
+
+ //generate random vectors with same distribution
+ randword2vec(srcdict->decode(f),W2V[f]);
+
+ newwords++;
+
+ if (verbosity){
+ for (int d=0;d<D;d++) cerr << " " << W2V[f][d]; cerr << "\n";}
+
+ }
+ }
+
+ cerr << "Generated " << newwords << " missing vectors\n";
+
+
+ if (normalize_vectors){
+ cerr << "Normalizing vectors\n";
+ for (f=0;f<srcdict->size();f++)
+ if (W2V[f]!=NULL){
+ float norm=0;
+ for (int d=0;d<D;d++) norm+=W2V[f][d]*W2V[f][d];
+ norm=sqrt(norm);
+ for (int d=0;d<D;d++) W2V[f][d]/=norm;
+ }
+ }
+
+};
+
+// Initialize the Gaussian mixture of target word e: one component per
+// "friend" source word when friends are available (means copied from the
+// friend's word2vec vector), otherwise a single component with a
+// deterministic random mean. Weights start uniform, variances wide.
+void cswam::initEntry(int e){
+ assert(TM[e].G==NULL);
+
+ //allocate a suitable number of gaussians
+
+ TM[e].n=(friends && friends[e].size()?friends[e].size():1);
+
+ assert(TM[e].n>0);
+ TM[e].G=new Gaussian [TM[e].n];TM[e].W=new float[TM[e].n];
+ for (int n=0;n<TM[e].n;n++){
+
+ TM[e].G[n].M=new float [D];
+ TM[e].G[n].S=new float [D];
+
+ //reset expected-count and variance-support statistics
+ TM[e].G[n].eC=0;
+ TM[e].G[n].mS=0;
+
+ //uniform mixture weights
+ TM[e].W[n]=1.0/(float)TM[e].n;
+
+ if (friends && friends[e].size()){
+ int f=friends[e][n].word; //initialize with source vector
+ memcpy(TM[e].G[n].M,W2V[f],sizeof(float) * D);
+ }
+ else{
+ randword2vec(trgdict->decode(e),TM[e].G[n].M,n);
+ }
+
+ for (int d=0;d<D;d++)
+ TM[e].G[n].S[d]=min_variance * SSEED; //take a wide standard deviation
+
+ }
+
+}
+
+
+
+//void oldinitEntry(int e){
+//
+// assert(TM[e].G==NULL);
+//
+// //allocate a suitable number of gaussians
+// TM[e].n=(int)ceil(log((double)trgdict->freq(e)+1.1));
+//
+// //some exceptions if
+//
+// assert(TM[e].n>0);
+// TM[e].G=new Gaussian [TM[e].n];TM[e].W=new float[TM[e].n];
+// for (int n=0;n<TM[e].n;n++){
+//
+// TM[e].G[n].M=new float [D];
+// TM[e].G[n].S=new float [D];
+//
+// TM[e].G[n].eC=0;
+// TM[e].G[n].mS=0;
+//
+// TM[e].W[n]=1.0/(float)TM[e].n;
+//
+// //initialize with w2v value if the same word is also in src
+// int f=srcdict->encode(trgdict->decode(e));
+// float srcfreq=srcdict->freq(f);float trgfreq=trgdict->freq(e);
+// if (f!=srcdict->oovcode() && srcfreq/trgfreq < 1.1 && srcfreq/trgfreq > 0.9 && srcfreq < 10 && f!=srcBoD && f!=srcEoD){
+// memcpy(TM[e].G[n].M,W2V[f],sizeof(float) * D);
+// for (int d=0;d<D;d++) TM[e].G[n].S[d]=min_variance; //dangerous!!!!
+// if (verbosity) cerr << "Biasing verbatim translation of " << srcdict->decode(f) << "\n";
+// }else{
+// //pick candidates from friends
+//
+//
+// randword2vec(trgdict->decode(e),TM[e].G[n].M,n);
+//
+// for (int d=0;d<D;d++)
+// TM[e].G[n].S[d]=W2Vsd[d] * 10; //take a wide standard deviation
+// }
+// }
+//
+//}
+
+// Initialize the translation model: load it from modelfile when that file
+// exists (training mode, model may be expanded), otherwise build a fresh
+// model with a default distortion distribution, friends lists, and one
+// mixture entry per target word. NullProb is (re)set in both cases.
+void cswam::initModel(char* modelfile){
+
+ //test if model is readable
+ bool model_available=false;
+ FILE* f;if ((f=fopen(modelfile,"r"))!=NULL){fclose(f);model_available=true;}
+
+ if (model_available)
+ loadModel(modelfile,true); //we are in training mode!
+ else{
+ cerr << "Initialize model\n";
+
+ if (use_beta_distortion){
+ DistMean=0.5;DistVar=1.0/12.0; //uniform distribution on 0,1
+ EstimateBeta(DistA,DistB,DistMean,DistVar);
+ }else{
+ DistMean=0;DistVar=10; //gaussian distribution over -1,+1: almost uniform
+ }
+
+ TM=new TransModel[trgdict->size()];
+
+ friends=new FriendList[trgdict->size()];
+ findfriends(friends);
+
+ for (int e=0; e<trgdict->size(); e++) initEntry(e);
+
+ }
+ //this can overwrite existing model
+ if (use_null_word)
+ NullProb=(fix_null_prob?fix_null_prob:0.05); //null word alignment probability
+
+}
+
+// Dump the model in human-readable form: distortion parameters, null-word
+// probability, then for each target word a header (=h=) and per-component
+// weight (=w=), mean (=m=) and variance (=s=) lines. Returns 1.
+int cswam::saveModelTxt(char* fname){
+ cerr << "Writing model into: " << fname << "\n";
+ mfstream out(fname,ios::out);
+ out << "=dist= " << DistMean << " " << DistVar << "\n";
+ out << "=nullprob= " << NullProb << "\n";
+ for (int e=0; e<trgdict->size(); e++){
+ out << "=h= " << trgdict->decode(e) << " sz= " << TM[e].n << "\n";
+ for (int n=0;n<TM[e].n;n++){
+ out << "=w= " << trgdict->decode(e) << " w= " << TM[e].W[n] << " eC= " << TM[e].G[n].eC << " mS= " << TM[e].G[n].mS << "\n";
+ out << "=m= " << trgdict->decode(e); for (int d=0;d<D;d++) out << " " << TM[e].G[n].M[d] ;out << "\n";
+ out << "=s= " << trgdict->decode(e); for (int d=0;d<D;d++) out << " " << TM[e].G[n].S[d]; out << "\n";
+ }
+ }
+ return 1;
+}
+
+// Serialize the model in binary form: a "CSWAM <D>" text header, the target
+// dictionary, the three global floats (DistMean, DistVar, NullProb), then
+// for each target word the component count, weights, means and variances.
+// Must be kept in sync with loadModel. Returns 1.
+int cswam::saveModel(char* fname){
+ cerr << "Saving model into: " << fname << " ...";
+ mfstream out(fname,ios::out);
+ out << "CSWAM " << D << "\n";
+ trgdict->save(out);
+ out.write((const char*)&DistMean,sizeof(float));
+ out.write((const char*)&DistVar,sizeof(float));
+ out.write((const char*)&NullProb,sizeof(float));
+ for (int e=0; e<trgdict->size(); e++){
+ out.write((const char*)&TM[e].n,sizeof(int));
+ out.write((const char*)TM[e].W,TM[e].n * sizeof(float));
+ for (int n=0;n<TM[e].n;n++){
+ out.write((const char*)TM[e].G[n].M,sizeof(float) * D);
+ out.write((const char*)TM[e].G[n].S,sizeof(float) * D);
+ }
+ }
+ out.close();
+ cerr << "\n";
+ return 1;
+}
+
+// Load a binary model written by saveModel. When `expand` is true (training
+// mode) the stored dictionary is extended with the current target dictionary
+// and new entries get freshly initialized mixtures; otherwise the model is
+// used as-is. Replaces trgdict with the model's dictionary. Returns 1.
+// Fixes: the sscanf result is now checked (previously a file without the
+// CSWAM header left `r` uninitialized), and a message typo was corrected.
+int cswam::loadModel(char* fname,bool expand){
+
+ cerr << "Loading model from: " << fname << "...";
+ mfstream inp(fname,ios::in);
+ char header[100];
+ inp.getline(header,100);
+ cerr << header ;
+ int r;
+ if (sscanf(header,"CSWAM %d\n",&r)!=1)
+ exit_error(IRSTLM_ERROR_DATA, "incompatible model file: missing CSWAM header");
+ if (D>0 && r != D)
+ exit_error(IRSTLM_ERROR_DATA, "incompatible dimension in model");
+ else
+ D=r;
+
+ if (verbosity) cerr << "\nLoading dictionary ... ";
+ dictionary* dict=new dictionary(NULL,1000000);
+ dict->load(inp);
+ dict->encode(dict->OOV());
+ int current_size=dict->size();
+
+ //expand the model for training or keep the model fixed for testing
+ if (expand){
+ if (verbosity)
+ cerr << "\nExpanding model to include target dictionary";
+ dict->incflag(1);
+ for (int code=0;code<trgdict->size();code++)
+ dict->encode(trgdict->decode(code));
+ dict->incflag(2);
+ }
+ //replace the trgdict with the model dictionary
+ delete trgdict;trgdict=dict;
+ trgdict->encode(trgdict->OOV()); //updated dictionary codes
+ trgBoD = trgdict->encode(trgdict->BoD()); //codes for begin/end sentence markers
+ trgEoD = trgdict->encode(trgdict->EoD());
+
+
+ TM=new TransModel [trgdict->size()];
+
+ if (verbosity) cerr << "\nReading parameters .... ";
+ inp.read((char*)&DistMean, sizeof(float));
+ inp.read((char*)&DistVar, sizeof(float));
+ inp.read((char*)&NullProb,sizeof(float));
+
+ cerr << "DistMean: " << DistMean << " DistVar: " << DistVar << " NullProb: " << NullProb << "\n";
+ if (use_beta_distortion)
+ EstimateBeta(DistA,DistB,DistMean,DistVar);
+
+ for (int e=0; e<current_size; e++){
+ inp.read((char *)&TM[e].n,sizeof(int));
+ TM[e].W=new float[TM[e].n];
+ inp.read((char *)TM[e].W,sizeof(float) * TM[e].n);
+ TM[e].G=new Gaussian[TM[e].n];
+ for (int n=0;n<TM[e].n;n++){
+ TM[e].G[n].M=new float [D];
+ TM[e].G[n].S=new float [D];
+ inp.read((char *)TM[e].G[n].M,sizeof(float) * D);
+ inp.read((char *)TM[e].G[n].S,sizeof(float) * D);
+ TM[e].G[n].eC=0;TM[e].G[n].mS=0;
+ }
+ }
+ inp.close();
+
+ cerr << "\nInitializing " << trgdict->size()-current_size << " new entries .... ";
+ for (int e=current_size; e<trgdict->size(); e++) initEntry(e);
+ cerr << "\nDone\n";
+ return 1;
+}
+
+// Allocate (on first call) and zero the E-step accumulators: A[s][i][n][j]
+// holds the expected count of aligning source position j to component n of
+// target position i in sentence pair s; Den[e][n] accumulates per-component
+// denominators for each target word e.
+void cswam::initAlphaDen(){
+
+ //install Alpha[s][i][j] to collect counts
+ //allocate if empty
+
+ if (A==NULL){
+ assert(trgdata->numdoc()==srcdata->numdoc());
+ A=new float ***[trgdata->numdoc()];
+ for (int s=0;s<trgdata->numdoc();s++){
+ A[s]=new float **[trgdata->doclen(s)];
+ for (int i=0;i<trgdata->doclen(s);i++){
+ A[s][i]=new float *[TM[trgdata->docword(s,i)].n];
+ for (int n=0;n<TM[trgdata->docword(s,i)].n;n++)
+ A[s][i][n]=new float [srcdata->doclen(s)];
+ }
+ }
+ }
+ //initialize
+ for (int s=0;s<trgdata->numdoc();s++)
+ for (int i=0;i<trgdata->doclen(s);i++)
+ for (int n=0;n<TM[trgdata->docword(s,i)].n;n++)
+ memset(A[s][i][n],0,sizeof(float) * srcdata->doclen(s));
+
+ //allocate
+ if (Den==NULL){
+ Den=new float*[trgdict->size()];
+ for (int e=0;e<trgdict->size();e++)
+ Den[e]=new float[TM[e].n];
+ }
+
+ //initialize
+ for (int e=0;e<trgdict->size();e++)
+ memset(Den[e],0,sizeof(float)*TM[e].n);
+}
+
+// Release the E-step accumulators allocated by initAlphaDen, resetting the
+// pointers to NULL so that a later initAlphaDen reallocates them.
+void cswam::freeAlphaDen(){
+
+ if (A!=NULL){
+ for (int s=0;s<trgdata->numdoc();s++){
+ for (int i=0;i<trgdata->doclen(s);i++){
+ for (int n=0;n<TM[trgdata->docword(s,i)].n;n++)
+ delete [] A[s][i][n];
+ delete [] A[s][i];
+ }
+ delete [] A[s];
+ }
+ delete [] A;
+ A=NULL;
+ }
+
+ if (Den!=NULL){
+ for (int e=0;e<trgdict->size();e++) delete [] Den[e];
+ delete [] Den; Den=NULL;
+ }
+
+}
+
+///*****
+//pthread_mutex_t cswam_mut1;
+//pthread_mutex_t cswam_mut2;
+double cswam_LL=0; //Log likelihood
+
+//numerically stable log(exp(a)+exp(b)): factor out the larger exponent
+float logsum(float a,float b){
+ float hi = (b < a) ? a : b;
+ float lo = (b < a) ? b : a;
+ return hi + logf(1 + expf(lo - hi));
+}
+
+int global_i=0;
+int global_j=0;
+
+// Gaussian-style log-score of vector x under mean m and per-dimension
+// variances s (dim components): -0.5*(Mahalanobis distance + normalizer).
+// NOTE(review): the normalizer uses logf of the *sum* of variances rather
+// than the sum of log-variances required by an exact diagonal Gaussian
+// density -- confirm whether this is an intentional approximation.
+float cswam::LogGauss(const int dim,const float* x,const float *m, const float *s){
+
+ static float log2pi=1.83787; //log(2 pi)
+ float dist=0; float norm=0;
+
+ for (int i=0;i<dim;i++){
+ assert(s[i]>0);
+ dist+=(x[i]-m[i])*(x[i]-m[i])/(s[i]);
+ norm+=s[i];
+ }
+
+ return -0.5 * (dist + dim * log2pi + logf(norm));
+
+}
+
+
+// Log-density of the Beta(a,b) distribution at x in (0,1), dropping the
+// constant log-normalization term (the Beta function), which cancels out
+// in the EM ratios where this score is used.
+float cswam::LogBeta( float x,float a,float b){
+
+ assert(x>0 && x <1);
+
+ //disregard constant factor!
+
+ return (a-1) * log(x) + (b-1) * log(1-x);
+
+}
+
+
+// Relative distortion of target position i vs source position j, for a
+// target sentence of length l and source sentence of length m: (i-j) is
+// scaled by l or m so the value lies in [-1,+1], then mapped to [0,1] when
+// the beta distortion model is used. Positions are shifted by one when the
+// null word occupies target slot 0.
+float cswam::Delta(int i,int j,int l,int m){
+
+ i-=(use_null_word?1:0);
+ l-=(use_null_word?1:0);
+
+ float d=((i - j)>0?(float)(i-j)/l:(float)(i-j)/m); //range is [-1,+1];
+ if (use_beta_distortion) d=(d+1)/2; //move in range [0,1];
+
+ //reduce length penalty for short sentences
+ if (l<=6 || m<=6) d/=2;
+
+ return d;
+}
+
+//Log-probability of distortion value d under the configured distortion
+//model: Beta(DistA,DistB) when enabled, 1-D Gaussian otherwise.
+float cswam::LogDistortion(float d){
+ return use_beta_distortion
+ ? LogBeta(d,DistA,DistB)
+ : LogGauss(1,&d,&DistMean,&DistVar);
+}
+
+
+
+// E-step worker for one sentence pair s (thread body): for every source
+// position j, score all candidate target positions i (null word or within
+// the distortion window) and mixture components n in log space, accumulate
+// the log-denominator with logsum, add it to the sentence log-likelihood,
+// then normalize A[s][i][n][j] into posterior expected counts.
+void cswam::expected_counts(void *argv){
+
+ long long s=(long long) argv;
+
+ ShowProgress(s, srcdata->numdoc());
+
+ int trglen=trgdata->doclen(s); // length of target sentence
+ int srclen=srcdata->doclen(s); //length of source sentence
+
+ float den;float delta=0; //distortion
+
+ //reset likelihood
+ localLL[s]=0;
+
+ //compute denominator for each source-target pair
+ for (int j=0;j<srclen;j++){
+ //qcout << "j: " << srcdict->decode(srcdata->docword(s,j)) << "\n";
+ den=0;
+ for (int i=0;i<trglen;i++)
+ if ((use_null_word && i==0) || abs(i-j-1) <= distortion_window){
+ delta=Delta(i,j,trglen,srclen);
+ for (int n=0;n<TM[trgdata->docword(s,i)].n;n++){
+ if (!(TM[trgdata->docword(s,i)].W[n]>0))
+ cerr << trgdict->decode(trgdata->docword(s,i)) << " n:" << n << "\n";
+ assert(TM[trgdata->docword(s,i)].W[n]>0); //weight zero must be prevented!!!
+ //global_i=i;
+ //cout << "i: " << trgdict->decode(trgdata->docword(s,i)) << "\n";
+ //log score = Gaussian emission + mixture weight
+ // + null/non-null prior + distortion (non-null only)
+ A[s][i][n][j]=LogGauss(D, W2V[srcdata->docword(s,j)],
+ TM[trgdata->docword(s,i)].G[n].M,
+ TM[trgdata->docword(s,i)].G[n].S)
+ +log(TM[trgdata->docword(s,i)].W[n])
+ +(i>0 || !use_null_word ?logf(1-NullProb):logf(NullProb))
+ +(i>0 || !use_null_word ?LogDistortion(delta):0);
+
+ if (i==0 && n==0) //den must be initialized
+ den=A[s][i][n][j];
+ else
+ den=logsum(den,A[s][i][n][j]);
+ }
+ }
+ //update local likelihood
+ localLL[s]+=den;
+
+ for (int i=0;i<trglen;i++)
+ if ((use_null_word && i==0) || abs(i-j-1) <= distortion_window)
+ for (int n=0;n<TM[trgdata->docword(s,i)].n;n++){
+
+ assert(A[s][i][n][j]<= den);
+
+ A[s][i][n][j]=expf(A[s][i][n][j]-den); // A is now a regular expected count
+
+ if (A[s][i][n][j]<0.000000001) A[s][i][n][j]=0; //take small risk of wrong normalization
+
+ if (A[s][i][n][j]>0) TM[trgdata->docword(s,i)].G[n].eC++; //increase support set size
+
+ }
+ }
+
+
+
+}
+
+//Method-of-moments fit of Beta(a,b) parameters from mean m and variance s.
+//The expanded expression below equals b = (1-m)*(m*(1-m)/s - 1), and then
+//a = m*b/(1-m) follows from m = a/(a+b).
+//NOTE(review): requires 0<m<1 and s < m*(1-m); otherwise a,b come out
+//non-positive -- not checked here, confirm callers guarantee it.
+void cswam::EstimateBeta(float &a, float &b, float m, float s){
+
+ b = (s * m -s + m * m * m - 2 * m * m + m)/s;
+ a = ( m * b )/(1-m);
+}
+
+
+//M-step worker. Tasks d=0..D-1 each update dimension d of every Gaussian
+//mean (and variance, if trained); the special task d==D re-estimates the
+//global distortion model (mean/variance or Beta parameters) and NullProb.
+//Relies on A (posteriors) and Den (per-word normalizers) filled beforehand.
+void cswam::maximization(void *argv){
+
+ long long d=(long long) argv;
+
+ ShowProgress(d,D);
+
+ if (d==D){
+ //this thread is to maximize the global distortion model
+ //Maximization step: Mean and variance of distortion model
+
+ //Mean
+
+ double totwdist=0, totdistprob=0, totnullprob=0, delta=0;
+ for (int s=0;s<srcdata->numdoc();s++){
+ for (int j=0;j<srcdata->doclen(s);j++)
+ for (int i=0;i<trgdata->doclen(s);i++)
+ if ((use_null_word && i==0) || abs(i-j-1) <= distortion_window){
+ delta=Delta(i,j,trgdata->doclen(s),srcdata->doclen(s));
+ for (int n=0;n<TM[trgdata->docword(s,i)].n;n++)
+ if (A[s][i][n][j]>0){
+ if (i>0 || !use_null_word){
+ totwdist+=A[s][i][n][j]*delta;
+ totdistprob+=A[s][i][n][j];
+ }
+ else{
+ totnullprob+=A[s][i][n][j];
+ }
+ }
+ }
+ }
+
+ if (use_null_word && fix_null_prob==0)
+ NullProb=(float)totnullprob/(totdistprob+totnullprob);
+
+ //NOTE(review): divisions below assume totdistprob>0; confirm degenerate
+ //corpora (all mass on the null word) cannot reach this point
+ if (distortion_mean && iter >0) //then update the mean
+ DistMean=totwdist/totdistprob;
+
+
+ //Variance
+ if (distortion_var && iter >0){
+ double totwdeltadist=0;
+ for (int s=0;s<srcdata->numdoc();s++)
+ for (int i=1;i<trgdata->doclen(s);i++) //exclude i=0!
+ for (int j=0;j<srcdata->doclen(s);j++)
+ if (abs(i-j-1) <= distortion_window){
+ delta=Delta(i,j,trgdata->doclen(s),srcdata->doclen(s));
+ for (int n=0;n<TM[trgdata->docword(s,i)].n;n++)
+ if (A[s][i][n][j]>0)
+ totwdeltadist+=A[s][i][n][j] * (delta-DistMean) * (delta-DistMean);
+
+ }
+
+ DistVar=totwdeltadist/totdistprob;
+ }
+
+ cerr << "Dist: " << DistMean << " " << DistVar << "\n";
+
+ if (use_null_word)
+ cerr << "NullProb: " << NullProb << "\n";
+
+ if (use_beta_distortion){
+ cerr << "Beta A: " << DistA << " Beta B: " << DistB << "\n";
+ EstimateBeta(DistA,DistB,DistMean,DistVar);
+ }
+
+ }
+ else{
+ //Maximization step: Mean;
+ //first pass: accumulate posterior-weighted source vectors into dimension d
+ for (int s=0;s<srcdata->numdoc();s++)
+ for (int j=0;j<srcdata->doclen(s);j++)
+ for (int i=0;i<trgdata->doclen(s);i++)
+ if ((use_null_word && i==0) || abs(i-j-1) <= distortion_window)
+ for (int n=0;n<TM[trgdata->docword(s,i)].n;n++)
+ if (A[s][i][n][j]>0)
+ TM[trgdata->docword(s,i)].G[n].M[d]+=A[s][i][n][j] * W2V[srcdata->docword(s,j)][d];
+
+ //second pass
+ for (int e=0;e<trgdict->size();e++)
+ for (int n=0;n<TM[e].n;n++)
+ if (Den[e][n]>0)
+ TM[e].G[n].M[d]/=Den[e][n]; //update the mean estimated
+
+ if (train_variances){
+ //Maximization step: Variance;
+ //uses the just-updated means, so this must run after the mean update
+ for (int s=0;s<srcdata->numdoc();s++)
+ for (int j=0;j<srcdata->doclen(s);j++)
+ for (int i=0;i<trgdata->doclen(s);i++)
+ if ((use_null_word && i==0) || abs(i-j-1) <= distortion_window)
+ for (int n=0;n<TM[trgdata->docword(s,i)].n;n++)
+ if (A[s][i][n][j]>0)
+ TM[trgdata->docword(s,i)].G[n].S[d]+=
+ (A[s][i][n][j] *
+ (W2V[srcdata->docword(s,j)][d]-TM[trgdata->docword(s,i)].G[n].M[d]) *
+ (W2V[srcdata->docword(s,j)][d]-TM[trgdata->docword(s,i)].G[n].M[d])
+ );
+
+ //second pass
+ for (int e=0;e<trgdict->size();e++)
+ for (int n=0;n<TM[e].n;n++)
+ if (Den[e][n]>0){
+ TM[e].G[n].S[d]/=Den[e][n]; //might be too aggressive?
+ if (TM[e].G[n].S[d] < min_variance) TM[e].G[n].S[d]=min_variance; //improves generalization!
+ }
+ }
+ }
+
+}
+
+
+//Split step for target word e: any Gaussian whose mean variance stopped
+//shrinking, has a large support set and a large variance is split into two
+//components with perturbed means and halved weight. One task per word.
+void cswam::expansion(void *argv){
+
+ long long e=(long long) argv;
+ for (int n=0;n<TM[e].n;n++){
+ //get mean of variances
+ float S=0; for (int d=0;d<D;d++) S+=TM[e].G[n].S[d]; S/=D;
+
+ //variance treshold and population threshold
+ float SThresh=5 * min_variance; float eCThresh=10;
+
+ //show large support set and variances that do not reduce: more aggressive split
+ if ((S/TM[e].G[n].mS) >= 0.95 && //mean variance does not reduce significantly
+ TM[e].G[n].eC >= eCThresh && //population is large
+ S > SThresh) { //variance is large
+ if (verbosity)
+ cerr << "\n" << trgdict->decode(e) << " n= " << n << " (" << TM[e].n << ") Counts: "
+ << TM[e].G[n].eC << " mS: " << S << "\n";
+
+ //expand: create new Gaussian after Gaussian n
+ //NOTE: memcpy is a shallow copy -- the M/S pointers move into the new
+ //array, and the old array is deleted below without freeing them, so
+ //ownership transfers correctly
+ Gaussian *nG=new Gaussian[TM[e].n+1];
+ float *nW=new float[TM[e].n+1];
+ memcpy((void *)nG,(const void *)TM[e].G, (n+1) * sizeof(Gaussian));
+ memcpy((void *)nW,(const void *)TM[e].W, (n+1) * sizeof(float));
+ if (n+1 < TM[e].n){
+ memcpy((void *)&nG[n+2],(const void*)&TM[e].G[n+1],(TM[e].n-n-1) * sizeof(Gaussian));
+ memcpy((void *)&nW[n+2],(const void*)&TM[e].W[n+1],(TM[e].n-n-1) * sizeof(float));
+ }
+ //initialize mean and variance vectors
+ nG[n+1].M=new float[D];nG[n+1].S=new float[D];
+ for (int d=0;d<D;d++){ //assign new means, keep old variances
+
+ //push the two means two standard deviations apart, in opposite directions
+ nG[n+1].M[d]=nG[n].M[d] + 2 * sqrt(nG[n].S[d]);
+ nG[n].M[d]=nG[n].M[d] - 2 * sqrt(nG[n].S[d]);
+
+ nG[n+1].S[d]=nG[n].S[d]=(2 * nG[n].S[d]); //enlarge a bit the variance: maybe better increase
+ }
+ nG[n+1].eC=nG[n].eC;
+ nG[n+1].mS=nG[n].mS=S;
+
+ //initialize weight vectors uniformly over n and n+1
+ nW[n+1]=nW[n]/2;nW[n]=nW[n]/2;
+
+ //update TM[e] structure
+ TM[e].n++;
+ delete [] TM[e].G;TM[e].G=nG;
+ delete [] TM[e].W; TM[e].W=nW;
+
+ //we increment loop variable by 1
+ n++;
+ }else{
+ TM[e].G[n].mS=S;
+ }
+
+ }
+
+}
+
//Relative squared L2 distance between d-dimensional vectors a and b:
//||a-b||^2 / ||a||^2, falling back to 1 when a is the zero vector.
//FIX: the original assigned (=) instead of accumulating (+=) inside the
//loop, so only the last component contributed to both numerator and
//denominator.
float rl2(const float* a,const float*b, int d){
  float dist=0;
  float norm=0;
  for (int i=0;i<d;i++){
    dist+=(a[i]-b[i])*(a[i]-b[i]);
    norm+=a[i]*a[i];
  }
  return (norm>0?dist/norm:1);
}
+
//Maximum per-component relative L1 distance: max_i |a[i]-b[i]| / a[i].
//FIX: use fabsf instead of unqualified abs -- depending on headers and
//namespace, abs(float) can resolve to the integer overload and truncate
//the difference to 0 or 1.
//NOTE(review): assumes a[i] > 0 (no guard against division by zero), as
//in the original -- confirm callers guarantee it.
float rl1(const float* a,const float*b, int d){
  float maxreldist=0; float reldist;
  for (int i=0;i<d;i++){
    reldist=fabsf(a[i]-b[i])/a[i];
    if (reldist>maxreldist) maxreldist=reldist;
  }
  return maxreldist;
}
+
//Maximum per-component absolute (L-infinity) distance between a and b.
//FIX: use fabsf instead of unqualified abs -- depending on headers and
//namespace, abs(float) can resolve to the integer overload and truncate
//the difference, making contraction() merge far-apart Gaussians.
float al1(const float* a,const float*b, int d){
  float maxdist=0; float dist;
  for (int i=0;i<d;i++){
    dist=fabsf(a[i]-b[i]);
    if (dist>maxdist) maxdist=dist;
  }
  return maxdist;
}
+
+//Prune step for target word e: remove mixture components whose weight is
+//negligible or whose mean overlaps an earlier component (distance below
+//the minimum standard deviation); then renormalize weights. One task per
+//word, writes only TM[e].
+//FIX: the mean/variance buffers of a removed component were leaked -- the
+//array delete releases the Gaussian structs but not their M/S vectors.
+void cswam::contraction(void *argv){
+
+ long long e=(long long) argv;
+
+ float min_std=sqrt(min_variance);
+ float min_weight=0.01;
+
+ for (int n=0;n<TM[e].n;n++){
+ int n1=0;
+ // look if the component overlaps with some of the previous ones
+ float max_dist=1;
+ for (n1=0;n1<n;n1++) if ((max_dist=al1(TM[e].G[n].M,TM[e].G[n1].M,D))< min_std) break;
+
+ //remove insignificant and overlapping gaussians (relative distance below minimum variance
+ if (TM[e].W[n] < min_weight || max_dist < min_std) { //eliminate this component
+ assert(TM[e].n>1);
+ if (verbosity) cerr << "\n" << trgdict->decode(e) << " n= " << n << " Weight: " << TM[e].W[n] << " Dist= " << max_dist << "\n";
+ //expand: create new mixture model with n-1 components
+ //(shallow memcpy: the surviving components' M/S pointers move into nG)
+ Gaussian *nG=new Gaussian[TM[e].n-1];
+ float *nW=new float[TM[e].n-1];
+ if (n>0){ //copy all entries before n
+ memcpy((void *)nG,(const void *)TM[e].G, n * sizeof(Gaussian));
+ memcpy((void *)nW,(const void *)TM[e].W, n * sizeof(float));
+ }
+ if (n+1 < TM[e].n){ //copy all entries after
+ memcpy((void *)&nG[n],(const void*)&TM[e].G[n+1],(TM[e].n-n-1) * sizeof(Gaussian));
+ memcpy((void *)&nW[n],(const void*)&TM[e].W[n+1],(TM[e].n-n-1) * sizeof(float));
+ }
+
+ //don't need to normalized weights!
+ if (max_dist < min_std)// this is the gaussian overlapping case
+ nW[n1]+=TM[e].W[n]; //the left gaussian inherits the weight
+
+ //free the dropped component's buffers before discarding the old array
+ delete [] TM[e].G[n].M; delete [] TM[e].G[n].S;
+
+ //update TM[e] structure
+ TM[e].n--;n--;
+ delete [] TM[e].G;TM[e].G=nG;
+ delete [] TM[e].W; TM[e].W=nW;
+ }
+ }
+
+ //re-normalize weights
+ float totw=0;
+ for (int n=0;n<TM[e].n;n++){totw+=TM[e].W[n]; assert(TM[e].W[n] > 0.0001);}
+ for (int n=0;n<TM[e].n;n++){TM[e].W[n]/=totw;};
+}
+
+//EM training driver: loads data, then for maxiter iterations runs the
+//E-step (expected_counts per sentence), the M-step (maximization per
+//dimension, plus the distortion task), weight re-estimation, and -- from
+//the second iteration on -- mixture pruning (contraction) and splitting
+//(expansion). The model is saved after every iteration. Returns 1.
+int cswam::train(char *srctrainfile, char*trgtrainfile,char *modelfile, int maxiter,int threads){
+
+ //initialize model
+ initModel(modelfile); //this might change the dictionary!
+
+ //Load training data
+
+ srcdata=new doc(srcdict,srctrainfile);
+ trgdata=new doc(trgdict,trgtrainfile,use_null_word); //use null word
+
+
+ iter=0;
+
+ cerr << "Starting training";
+ threadpool thpool=thpool_init(threads);
+ int numtasks=trgdict->size()>trgdata->numdoc()?trgdict->size():trgdata->numdoc();
+ task *t=new task[numtasks];
+ assert(numtasks>D); //multi-threading also distributed over D
+
+
+ //support variable to compute likelihood
+ localLL=new float[srcdata->numdoc()];
+
+ while (iter < maxiter){
+
+ cerr << "\nIteration: " << ++iter << "\n";
+
+ initAlphaDen();
+
+ //reset support set size
+ for (int e=0;e<trgdict->size();e++)
+ for (int n=0;n<TM[e].n;n++) TM[e].G[n].eC=0; //will be updated in E-step
+
+
+ cerr << "E-step: ";
+ //compute expected counts in each single sentence
+ for (long long s=0;s<srcdata->numdoc();s++){
+ //prepare and assign tasks to threads
+ t[s].ctx=this; t[s].argv=(void *)s;
+ thpool_add_work(thpool, &cswam::expected_counts_helper, (void *)&t[s]);
+
+ }
+ //join all threads
+ thpool_wait(thpool);
+
+
+ //Reset model before update
+ for (int e=0;e <trgdict->size();e++)
+ for (int n=0;n<TM[e].n;n++){
+ memset(TM[e].G[n].M,0,D * sizeof (float));
+ if (train_variances)
+ memset(TM[e].G[n].S,0,D * sizeof (float));
+ }
+
+ for (int e=0;e<trgdict->size();e++)
+ memset(Den[e],0,TM[e].n * sizeof(float));
+
+ cswam_LL=0; //compute LL of current model
+ //compute normalization term for each target word
+ for (int s=0;s<srcdata->numdoc();s++){
+ cswam_LL+=localLL[s];
+ for (int i=0;i<trgdata->doclen(s);i++)
+ for (int n=0;n<TM[trgdata->docword(s,i)].n;n++)
+ for (int j=0;j<srcdata->doclen(s);j++)
+ Den[trgdata->docword(s,i)][n]+=A[s][i][n][j];
+ }
+
+ cerr << "LL = " << cswam_LL << "\n";
+
+
+ cerr << "M-step: ";
+ for (long long d=0;d<=D;d++){ //include special job d=D for distortion model
+ t[d].ctx=this; t[d].argv=(void *)d;
+ thpool_add_work(thpool, &cswam::maximization_helper, (void *)&t[d]);
+ }
+
+ //join all threads
+ thpool_wait(thpool);
+
+ //some checks of the models: fix degenerate models
+ for (int e=0;e<trgdict->size();e++)
+ if (e!=trgEoD)
+ for (int n=0;n<TM[e].n;n++)
+ if (!Den[e][n]){
+ if (verbosity)
+ cerr << "\nRisk of degenerate model. Word: " << trgdict->decode(e) << " n: " << n << " eC:" << TM[e].G[n].eC << "\n";
+ //reset the variance to a large seed value to re-spread the component
+ for (int d=0;d<D;d++) TM[e].G[n].S[d]=SSEED * min_variance;
+ }
+
+
+// if (trgdict->encode("bege")==e){
+// cerr << "bege " << " mS: " << TM[e].G[0].mS << " n: " << TM[e].n << " eC " << TM[e].G[0].eC << "\n";
+// cerr << "M:"; for (int d=0;d<10;d++) cerr << " " << TM[e].G[0].M[d]; cerr << "\n";
+// cerr << "S:"; for (int d=0;d<10;d++) cerr << " " << TM[e].G[0].S[d]; cerr << "\n";
+// }
+// }
+
+ //update the weight estimates: no need of multithreading
+ float totW; int ngauss=0;
+ for (int e=0;e<trgdict->size();e++){
+ totW=0;
+ for (int n=0;n<TM[e].n;n++){ totW+=Den[e][n]; ngauss++;}
+ if (totW>0)
+ for (int n=0;n<TM[e].n;n++) TM[e].W[n]=Den[e][n]/totW;
+ }
+ cerr << "Num Gaussians: " << ngauss << "\n";
+
+ if (iter > 1 || incremental_train ){
+
+ freeAlphaDen(); //needs to be reallocated as models might change
+
+ cerr << "\nP-step: ";
+ for (long long e=0;e<trgdict->size();e++){
+ //check if to decrease number of gaussians per target word
+ t[e].ctx=this; t[e].argv=(void *)e;
+ thpool_add_work(thpool, &cswam::contraction_helper, (void *)&t[e]);
+ }
+ //join all threads
+ thpool_wait(thpool);
+
+ cerr << "\nS-step: ";
+ for (long long e=0;e<trgdict->size();e++){
+ //check if to increase number of gaussians per target word
+ t[e].ctx=this; t[e].argv=(void *)e;
+ thpool_add_work(thpool, &cswam::expansion_helper, (void *)&t[e]);
+ }
+ //join all threads
+ thpool_wait(thpool);
+
+
+ }
+
+
+ //NOTE(review): progress timestamp only; system("date") return value ignored
+ if (srcdata->numdoc()>10000) system("date");
+
+ saveModel(modelfile);
+
+ }
+
+ // for (int e=0;e<trgdict->size();e++)
+ // for (int d=0;d<D;d++)
+ // cout << trgdict->decode(e) << " S: " << S[e][d] << " M: " << M[e][d]<< "\n";
+
+ //destroy thread pool
+ thpool_destroy(thpool);
+
+ freeAlphaDen();
+
+ delete srcdata; delete trgdata;
+ delete [] t; delete [] localLL;
+
+ return 1;
+}
+
+
+
+//Alignment worker for one sentence pair s: for every source position j,
+//pick the target position i maximizing mixture emission score plus
+//null/distortion score (greedy per-word Viterbi). Results go into
+//alignments[s % bucket]. If every word aligned to the null word, the
+//sentence is re-aligned once with the null word excluded.
+void cswam::aligner(void *argv){
+ long long s=(long long) argv;
+ static float maxfloat=std::numeric_limits<float>::max();
+
+
+ if (! (s % 10000)) {cerr << ".";cerr.flush();}
+ //fprintf(stderr,"Thread: %lu Document: %d (out of %d)\n",(long)pthread_self(),s,srcdata->numdoc());
+
+ int trglen=trgdata->doclen(s); // length of target sentence
+ int srclen=srcdata->doclen(s); //length of source sentence
+
+ assert(trglen<MAX_LINE);
+
+ //Viterbi alignment: find the most probable alignment for source
+ float score; float best_score;int best_i;float sum=0;
+
+ bool some_not_null=false; int first_target=0;
+
+ for (int j=0;j<srclen;j++){
+ //cout << "src: " << srcdict->decode(srcdata->docword(s,j)) << "\n";
+
+ best_score=-maxfloat;best_i=0;
+
+ for (int i=first_target;i<trglen;i++)
+ if ((use_null_word && i==0) || abs(i-j-1) <= distortion_window){
+ //cout << "tgt: " << trgdict->decode(trgdata->docword(s,i)) << " ";
+
+ //log-sum of the Gaussian mixture emission score
+ for (int n=0;n<TM[trgdata->docword(s,i)].n;n++){
+ score=LogGauss(D,
+ W2V[srcdata->docword(s,j)],
+ TM[trgdata->docword(s,i)].G[n].M,
+ TM[trgdata->docword(s,i)].G[n].S)+log(TM[trgdata->docword(s,i)].W[n]);
+ if (n==0) sum=score;
+ else sum=logsum(sum,score);
+ } //completed mixture score
+
+ if (distortion_var || distortion_mean){
+ if (i>0 ||!use_null_word){
+ float d=Delta(i,j,trglen,srclen);
+ sum+=logf(1-NullProb) + LogDistortion(d);
+ }
+ else
+ if (use_null_word ) sum+=logf(NullProb);
+ }
+ else //use plain distortion model
+ //linear penalty proportional to the jump size |i-j| (null word excluded)
+ if (i>0){
+ if (i - (use_null_word?1:0) > j )
+ sum-=(i- (use_null_word?1:0) -j);
+ else if (i - (use_null_word?1:0) < j )
+ sum-=(j - i + (use_null_word?1:0));
+ }
+ //add distortion score now
+
+ //cout << "score: " << sum << "\n";
+ // cout << "\t " << srcdict->decode(srcdata->docword(s,j)) << " " << dist << "\n";
+ //if (dist > -50) score=(float)exp(-dist)/norm;
+ if (sum > best_score){
+ best_score=sum;
+ best_i=i;
+ if ((!use_null_word || best_i>0) && !some_not_null) some_not_null=true;
+ }
+ }
+ //cout << "best_i: " << best_i << "\n";
+
+ alignments[s % bucket][j]=best_i;
+
+ //all words went to the null word: redo the whole sentence without it
+ if (j==(srclen-1) && !some_not_null){
+ j=-1; //restart loop and remove null word from options
+ first_target=1;
+ some_not_null=true; //make sure to pass this check next time
+ }
+ }
+
+}
+
+
+//Align a test corpus: loads the model and the parallel data, then runs
+//aligner() over sentences in buckets of BUCKET parallel tasks, appending
+//"j-i" word-alignment pairs (0-based, null links skipped) to alignfile,
+//one line per sentence. Returns 1 on success, 0 on config error.
+int cswam::test(char *srctestfile, char *trgtestfile, char* modelfile, char* alignfile,int threads){
+
+ {mfstream out(alignfile,ios::out);} //empty the file
+
+ initModel(modelfile);
+
+ if (!distortion_mean){
+ if (use_beta_distortion){
+ cerr << "ERROR: cannot test with beta distribution without mean\n";
+ return 0;
+ }
+ DistMean=0; //force mean to zero
+ }
+
+ //Load training data
+ srcdata=new doc(srcdict,srctestfile);
+ trgdata=new doc(trgdict,trgtestfile,use_null_word);
+ assert(srcdata->numdoc()==trgdata->numdoc());
+
+
+ bucket=BUCKET; //initialize the bucket size
+
+ alignments=new int* [BUCKET];
+ for (int s=0;s<BUCKET;s++)
+ alignments[s]=new int[MAX_LINE];
+
+ threadpool thpool=thpool_init(threads);
+ task *t=new task[BUCKET];
+
+ cerr << "Start alignment\n";
+
+ for (long long s=0;s<srcdata->numdoc();s++){
+
+ t[s % BUCKET].ctx=this; t[s % BUCKET].argv=(void *)s;
+ thpool_add_work(thpool, &cswam::aligner_helper, (void *)&t[s % BUCKET]);
+
+
+ //flush a bucket when it is full or the corpus ends
+ if (((s % BUCKET) == (BUCKET-1)) || (s==(srcdata->numdoc()-1)) ){
+ //join all threads
+ thpool_wait(thpool);
+
+ //cerr << "Start printing\n";
+
+ if ((s % BUCKET) != (BUCKET-1))
+ bucket=srcdata->numdoc() % bucket; //last bucket at end of file
+
+ mfstream out(alignfile,ios::out | ios::app);
+
+ for (int b=0;b<bucket;b++){ //includes the eof case of
+ //out << "Sentence: " << s-bucket+1+b;
+ bool first=true;
+ for (int j=0; j<srcdata->doclen(s-bucket+1+b); j++)
+ if (!use_null_word || alignments[b][j]>0){
+ //print target using 0 for first actual word
+ out << (first?"":" ") << j << "-" << alignments[b][j]-(use_null_word?1:0);
+ first=false;
+ }
+ out << "\n";
+ }
+ }
+
+ }
+
+
+ //destroy thread pool
+ thpool_destroy(thpool);
+
+ delete [] t;
+ for (int s=0;s<BUCKET;s++) delete [] alignments[s];delete [] alignments;
+ delete srcdata; delete trgdata;
+ return 1;
+}
+
+//find for each target word a list of associated source words
+
+typedef std::pair <int,float> mientry; //pair type containing src word and mi score (NOTE(review): appears unused in this file)
+bool myrank (Friend a,Friend b) { return (a.score > b.score ); } //comparator: sort Friends by decreasing score
+
+
+//void cswam::findfriends(FriendList* friends){
+//
+// typedef std::unordered_map<int, int> src_map;
+// src_map* table= new src_map[trgdict->size()];
+//
+// // amap["def"][7] = 2.2;
+// // std::cout << amap["abc"][12] << '\n';
+// // std::cout << amap["def"][7] << '\n';
+//
+//
+// int *srcfreq=new int[srcdict->size()];
+// int *trgfreq=new int[trgdict->size()];
+// int totfreq=0;
+// int minfreq=10;
+//
+// cerr << "collecting co-occurrences\n";
+// for (int s=0;s<srcdata->numdoc();s++){
+//
+// int trglen=trgdata->doclen(s); // length of target sentence
+// int srclen=srcdata->doclen(s); //length of source sentence
+//
+// int frac=(s * 1000)/srcdata->numdoc();
+// if (!(frac % 10)) fprintf(stderr,"%02d\b\b",frac/10);
+//
+// for (int i=0;i<trglen;i++){
+// int trg=trgdata->docword(s,i);
+// float trgdictfreq=trgdict->freq(trg);
+// if (trgdict->freq(trg)>=10){
+// for (int j=0;j<srclen;j++){
+// int src=srcdata->docword(s,j);
+// float freqratio=srcdict->freq(src)/trgdictfreq;
+// if (srcdict->freq(src)>=minfreq && freqratio <= 10 && freqratio >= 0.1){
+// table[trg][src]++;
+// totfreq++;
+// srcfreq[src]++;
+// trgfreq[trg]++;
+// }
+// }
+// }
+// }
+// }
+//
+// cerr << "computing mutual information\n";
+// Friend mie; FriendList mivec;
+//
+//
+// for (int i = 0; i < trgdict->size(); i++){
+//
+// int frac=(i * 1000)/trgdict->size();
+// if (!(frac % 10)) fprintf(stderr,"%02d\b\b",frac/10);
+//
+// mivec.clear();
+// for (auto jtr = table[i].begin(); jtr != table[i].end();jtr++){
+// int j=(*jtr).first; int freq=(*jtr).second;
+// float freqratio=(float)srcdict->freq(j)/(float)trgdict->freq(i);
+// if (freq>minfreq){ // && freqratio < 10 && freqratio > 0.1){
+// //compute mutual information
+// float mutualinfo=
+// logf(freq/(float)trgfreq[i]) - log((float)srcfreq[j]/totfreq);
+// mutualinfo/=log(2);
+// mie.word=j; mie.score=mutualinfo;
+// mivec.push_back(mie);
+// }
+// }
+// if (mivec.size()>0){
+// std::sort(mivec.begin(),mivec.end(),myrank);
+// //sort the vector and take the top log(10)
+// int count=0;
+// for (auto jtr = mivec.begin(); jtr != mivec.end();jtr++){
+// //int j=(*jtr).word; float mutualinfo=(*jtr).score;
+// friends[i].push_back(*jtr);
+// //cout << trgdict->decode(i) << " " << srcdict->decode(j) << " " << mutualinfo << endl;
+// //if (++count >=50) break;
+// }
+//
+// }
+// }
+//
+//
+//}
+
+
+
+//IBM Model-1 E-step worker for sentence s: for every admitted (e,f) pair
+//accumulates the normalized translation posterior p(e|f) into the
+//thread-local tables loc_efcounts/loc_ecounts of bucket s % threads.
+//Words below minfreq and probabilities below lowprob are skipped.
+void cswam::M1_ecounts(void *argv){
+ long long s=(long long) argv;
+
+ int b=s % threads; //value of the actual bucket
+ int trglen=trgdata->doclen(s); // length of target sentence
+ int srclen=srcdata->doclen(s); //length of source sentence
+ float pef=0;
+
+ ShowProgress(s,srcdata->numdoc());
+
+ float lowprob=0.0000001;
+
+ for (int j=0;j<srclen;j++){
+ int f=srcdata->docword(s,j);
+ if (srcdict->freq(f)>=minfreq){
+ //first pass: normalization term t = sum over admitted target words
+ float t=0;
+ for (int i=0;i<trglen;i++){
+ int e=trgdata->docword(s,i);
+ if (trgdict->freq(e)>=minfreq && (i==0 || abs(i-j-1) <= distortion_window) && prob[e][f]>lowprob)
+ t+=prob[e][f];
+ }
+ //second pass: distribute the posterior mass
+ for (int i=0;i<trglen;i++){
+ int e=trgdata->docword(s,i);
+ if (trgdict->freq(e)>=minfreq && (i==0 || abs(i-j-1) <= distortion_window) && prob[e][f]>lowprob){
+ pef=prob[e][f]/t;
+ loc_efcounts[b][e][f]+=pef;
+ loc_ecounts[b][e]+=pef;
+ }
+ }
+ }
+ }
+
+}
+
+//Model-1 M-step worker for target word e: renormalize the collected
+//expected counts into probabilities, p(f|e) = efcounts[e][f] / ecounts[e].
+void cswam::M1_update(void *argv){
+ long long e=(long long) argv;
+
+ ShowProgress(e,trgdict->size());
+
+ for (src_map::iterator jtr = efcounts[e].begin(); jtr != efcounts[e].end(); ++jtr)
+ prob[e][(*jtr).first] = (*jtr).second / ecounts[e];
+}
+
+//Fold the thread-local Model-1 expected counts for target word e into the
+//global accumulators, resetting each local table for the next iteration.
+void cswam::M1_collect(void *argv){
+ long long e=(long long) argv;
+
+ ShowProgress(e,trgdict->size());
+
+ for (int b=0;b<threads;b++){
+ ecounts[e]+=loc_ecounts[b][e];
+ loc_ecounts[b][e]=0; //reset local count
+ for (src_map::iterator jtr = loc_efcounts[b][e].begin(); jtr != loc_efcounts[b][e].end(); ++jtr)
+ efcounts[e][(*jtr).first]+=(*jtr).second;
+ loc_efcounts[b][e].clear(); //reset local counts
+ }
+}
+
+
+//Manage the Model-1 expected-count structures.
+//clearmem=false: lazily allocates the per-thread and global tables on the
+//first call, then resets the global ones (locals are reset by M1_collect).
+//clearmem=true: frees everything and resets the pointers.
+//FIXES: zero freshly allocated loc_ecounts (was read uninitialized in the
+//first E-step); hoist the global memset out of the per-word loop and size
+//it with sizeof(float) instead of sizeof(int); null freed pointers so a
+//later call can safely re-allocate.
+void cswam::M1_clearcounts(bool clearmem){
+
+ if (efcounts==NULL){
+ cerr << "allocating thread local structures\n";
+ //allocate thread safe structures
+ loc_efcounts=new src_map*[threads];
+ loc_ecounts=new float*[threads];
+ for (int b=0;b<threads;b++){
+ loc_efcounts[b]=new src_map[trgdict->size()];
+ loc_ecounts[b]=new float[trgdict->size()];
+ memset(loc_ecounts[b],0,sizeof(float)*trgdict->size()); //start from clean counts
+ }
+ cerr << "allocating global count structures\n";
+ //allocate the global count structures
+ efcounts=new src_map[trgdict->size()];
+ ecounts=new float[trgdict->size()];
+ }
+
+
+ if (clearmem){
+ for (int b=0;b<threads;b++){
+ delete [] loc_efcounts[b];
+ delete [] loc_ecounts[b];
+ }
+ delete [] loc_efcounts; delete [] loc_ecounts;
+ delete [] efcounts; delete [] ecounts;
+ //avoid dangling pointers: allow a later call to re-allocate
+ loc_efcounts=NULL; loc_ecounts=NULL;
+ efcounts=NULL; ecounts=NULL;
+ }else{
+ // cerr << "resetting expected counts\n";
+ for (int e = 0; e < trgdict->size(); e++)
+ efcounts[e].clear();
+ memset(ecounts,0,sizeof(float)*trgdict->size());
+ //local expected counts are reset in main loop
+ }
+
+}
+
+
+//Build, for each target word, a ranked list of candidate source
+//translations ("friends"): initialize a uniform IBM Model-1 table over
+//co-occurring frequent words, run M1iter EM iterations (multi-threaded),
+//then keep for each target word the top candidates, truncated at the
+//perplexity of its translation distribution and at probability >= 0.01.
+void cswam::findfriends(FriendList* friends){
+
+ //allocate the global prob table
+ prob= new src_map[trgdict->size()];
+
+ //allocate thread safe structures
+ M1_clearcounts(false);
+
+ //prepare thread pool
+ threadpool thpool=thpool_init(threads);
+ task *t=new task[trgdict->size()>threads?trgdict->size():threads];
+
+ float minprob=0.01;
+
+ cerr << "initializing M1\n";
+ for (int s=0;s<srcdata->numdoc();s++){
+ int trglen=trgdata->doclen(s); // length of target sentence
+ int srclen=srcdata->doclen(s); //length of source sentence
+
+ int frac=(s * 1000)/srcdata->numdoc();
+ if (!(frac % 10)) fprintf(stderr,"%02d\b\b",frac/10);
+
+ //seed the table: every admissible co-occurring pair starts at 1
+ for (int j=0;j<srclen;j++){
+ int f=srcdata->docword(s,j);
+ if (srcdict->freq(f)>=minfreq){
+ for (int i=0;i<trglen;i++){
+ int e=trgdata->docword(s,i);
+ if (trgdict->freq(e)>=minfreq && (i==0 || abs(i-j-1) <= distortion_window))
+ prob[e][f]=1;
+ }
+ }
+ }
+ }
+
+ cerr << "training M1\n";
+ for (int it=0;it<M1iter;it++){
+
+ cerr << "it: " << it+1;
+ M1_clearcounts(false);
+
+ //compute expected counts
+ for (long long s=0;s<srcdata->numdoc();s++){
+
+ t[s % threads].ctx=this; t[s % threads].argv=(void *)s;
+ thpool_add_work(thpool, &cswam::M1_ecounts_helper,(void *)&t[s % threads]);
+
+ if (((s % threads) == (threads-1)) || (s==(srcdata->numdoc()-1)))
+ thpool_wait(thpool);//join all threads
+ }
+
+ //update the global counts
+ for (long long e = 0; e < trgdict->size(); e++){
+ t[e].ctx=this; t[e].argv=(void *)e;
+ thpool_add_work(thpool, &cswam::M1_collect_helper,(void *)&t[e]);
+ }
+ thpool_wait(thpool);//join all threads
+
+ //update probabilities
+ for (long long e = 0; e < trgdict->size(); e++){
+ t[e].ctx=this; t[e].argv=(void *)e;
+ thpool_add_work(thpool, &cswam::M1_update_helper,(void *)&t[e]);
+ }
+
+ thpool_wait(thpool); //join all threads
+ }
+
+ cerr << "computing candidates\n";
+ Friend f;FriendList fv;
+
+ for (int e = 0; e < trgdict->size(); e++){
+
+ ShowProgress(e,trgdict->size());
+
+ fv.clear();
+ //save in a vector and compute entropy
+ float H=0;
+ for (src_map::iterator jtr = prob[e].begin(); jtr != prob[e].end();jtr++){
+ f.word=(*jtr).first; f.score=(*jtr).second;
+ assert(f.score>=0 && f.score<=1);
+ if (f.score>0)
+ H-=f.score * logf(f.score);
+ if (f.score >= minprob) //never include options with prob < minprob
+ fv.push_back(f);
+ }
+
+ std::sort(fv.begin(),fv.end(),myrank);
+ int PP=round(expf(H)); //compute perplexity
+
+ cout << trgdict->decode(e) << " # friends: " << fv.size() << " PP " << PP << endl;
+ int count=0;
+ for (FriendList::iterator jtr = fv.begin(); jtr != fv.end();jtr++){
+ friends[e].push_back(*jtr);
+ //if (verbosity)
+ cout << trgdict->decode(e) << " " << srcdict->decode((*jtr).word) << " " << (*jtr).score << endl;
+ if (++count >= PP) break; //keep at most PP candidates
+ }
+ }
+
+ //destroy thread pool
+ thpool_destroy(thpool); delete [] t;
+
+ M1_clearcounts(true);
+
+ delete [] prob;
+}
+
+} //namespace irstlm
+
diff --git a/src/cswam.h b/src/cswam.h
new file mode 100755
index 0000000..a77647f
--- /dev/null
+++ b/src/cswam.h
@@ -0,0 +1,219 @@
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit, compile LM
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+#ifndef MF_CSWAM_H
+#define MF_CSWAM_H
+
+#ifdef HAVE_CXX0
+#include <unordered_map>
+#else
+#include <map>
+#endif
+
+#include <vector>
+
+namespace irstlm {
+
+//Single diagonal-covariance Gaussian component of a mixture.
+typedef struct{
+
+ float* M; //mean vectors
+ float* S; //variance vectors
+ //training support items
+ float eC; //support set size
+ float mS; //mean variance
+
+} Gaussian;
+
+//Gaussian-mixture translation model attached to one target word.
+typedef struct{
+ int n; //number of Gaussians
+ float *W; //weight vector
+ Gaussian *G; //Gaussians
+} TransModel;
+
+//Candidate source-word translation with an association score.
+typedef struct{
+ int word; //word code
+ float score; //score (mutual information)
+} Friend;
+
+typedef std::vector<Friend> FriendList; //list of word Friends
+#ifdef HAVE_CXX0
+typedef std::unordered_map<int,float> src_map; //target to source associative memory
+#else
+typedef std::map<int,float> src_map; //target to source associative memory
+#endif
+
+//Continuous-space word alignment model: each target word emits source
+//word2vec vectors through a Gaussian mixture; training is EM with an
+//optional distortion model and null word, plus mixture split/prune steps.
+class cswam {
+
+ //data
+ dictionary* srcdict; //source dictionary
+ dictionary* trgdict; //target dictionary
+ doc* srcdata; //source training data
+ doc* trgdata; //target training data
+ FriendList* friends; //prior list of translation candidates
+
+ //word2vec
+ float **W2V; //vector for each source word
+ int D; //dimension of vector space
+
+
+ //model
+ TransModel *TM;
+ float DistMean,DistVar; //distortion mean and variance
+ float DistA,DistB; //beta distribution parameters (used when use_beta_distortion)
+ float NullProb; //null probability
+
+ //settings
+ bool normalize_vectors;
+ bool train_variances;
+ double fix_null_prob;
+ bool use_null_word;
+ bool verbosity;
+ float min_variance;
+ int distortion_window;
+ bool distortion_mean;
+ bool distortion_var;
+ bool use_beta_distortion;
+ int minfreq;
+ bool incremental_train;
+
+ //private info shared among threads
+ int trgBoD; //code of segment begin in target dict
+ int trgEoD; //code of segment end in target dict
+ int srcBoD; //code of segment begin in src dict
+ int srcEoD; //code of segment end in src dict
+
+ float ****A; //expected counts
+ float **Den; //alignment probs
+ float *localLL; //local log-likelihood
+ int **alignments; //word alignment info
+ int threads; //number of threads
+ int bucket; //size of bucket
+ int iter; //current iteration
+ int M1iter; //iterations with model 1
+
+ //Model 1 initialization private variables
+
+ src_map* prob; //model one probabilities
+ src_map** loc_efcounts; //per-thread expected pair counts
+ float **loc_ecounts; //per-thread expected word counts
+ src_map* efcounts; //global expected pair counts
+ float *ecounts; //global expected word counts
+
+ struct task { //basic task info to run task
+ void *ctx;
+ void *argv;
+ };
+
+
+public:
+
+ cswam(char* srcdatafile,char* trgdatafile, char* word2vecfile,
+ bool forcemodel,
+ bool usenull,double fix_null_prob,
+ bool normv2w,
+ int model1iter,
+ bool trainvar,float minvar,
+ int distwin,bool distbeta, bool distmean,bool distvar,
+ bool verbose);
+
+ ~cswam();
+
+ void loadword2vec(char* fname);
+ void randword2vec(const char* word,float* vec,int it=0);
+ void initModel(char* fname);
+ void initEntry(int entry);
+ int saveModel(char* fname);
+ int saveModelTxt(char* fname);
+ int loadModel(char* fname,bool expand=false);
+
+ void initAlphaDen();
+ void freeAlphaDen();
+
+
+ float LogGauss(const int dim,const float* x,const float *m, const float *s);
+
+ float LogDistortion(float d);
+ float LogBeta(float x, float a, float b);
+ void EstimateBeta(float &a, float &b, float m, float s);
+
+ float Delta( int i, int j, int l=1, int m=1);
+
+ //worker methods run by the thread pool; each static *_helper unpacks a
+ //task struct and forwards to the member function
+ void expected_counts(void *argv);
+ static void *expected_counts_helper(void *argv){
+ task t=*(task *)argv;
+ ((cswam *)t.ctx)->expected_counts(t.argv);return NULL;
+ };
+
+ void maximization(void *argv);
+ static void *maximization_helper(void *argv){
+ task t=*(task *)argv;
+ ((cswam *)t.ctx)->maximization(t.argv);return NULL;
+ };
+
+ void expansion(void *argv);
+ static void *expansion_helper(void *argv){
+ task t=*(task *)argv;
+ ((cswam *)t.ctx)->expansion(t.argv);return NULL;
+ };
+
+ void contraction(void *argv);
+ static void *contraction_helper(void *argv){
+ task t=*(task *)argv;
+ ((cswam *)t.ctx)->contraction(t.argv);return NULL;
+ };
+
+
+ void M1_ecounts(void *argv);
+ static void *M1_ecounts_helper(void *argv){
+ task t=*(task *)argv;
+ ((cswam *)t.ctx)->M1_ecounts(t.argv);return NULL;
+ }
+
+ void M1_collect(void *argv);
+ static void *M1_collect_helper(void *argv){
+ task t=*(task *)argv;
+ ((cswam *)t.ctx)->M1_collect(t.argv);return NULL;
+ }
+
+ void M1_update(void *argv);
+ static void *M1_update_helper(void *argv){
+ task t=*(task *)argv;
+ ((cswam *)t.ctx)->M1_update(t.argv);return NULL;
+ }
+
+ void M1_clearcounts(bool clearmem=false);
+
+ void findfriends(FriendList* friends);
+
+
+
+ int train(char *srctrainfile,char *trgtrainfile,char* modelfile, int maxiter,int threads=1);
+
+ void aligner(void *argv);
+ static void *aligner_helper(void *argv){
+ task t=*(task *)argv;
+ ((cswam *)t.ctx)->aligner(t.argv);return NULL;
+ };
+
+
+ int test(char *srctestfile, char* trgtestfile, char* modelfile,char* alignmentfile, int threads=1);
+
+};
+} //namespace irstlm
+#endif
diff --git a/src/dict.cpp b/src/dict.cpp
new file mode 100644
index 0000000..43bb8aa
--- /dev/null
+++ b/src/dict.cpp
@@ -0,0 +1,164 @@
+// $Id: dict.cpp 3677 2010-10-13 09:06:51Z bertoldi $
+
+
+#include <iostream>
+#include "cmd.h"
+#include "util.h"
+#include "mfstream.h"
+#include "mempool.h"
+#include "dictionary.h"
+
+using namespace std;
+
+
+// Print the dict tool banner, usage line, description and option list
+// on stderr (options are rendered by the cmd.h helper).
+void print_help(int TypeFlag=0){
+ std::ostream& err = std::cerr;
+ err << std::endl << "dict - extracts a dictionary" << std::endl
+ << std::endl << "USAGE:" << std::endl
+ << " dict -i=<inputfile> [options]" << std::endl
+ << std::endl << "DESCRIPTION:" << std::endl
+ << " dict extracts a dictionary from a corpus or a dictionary." << std::endl
+ << std::endl << "OPTIONS:" << std::endl;
+ FullPrintParams(TypeFlag, 0, 1, stderr);
+}
+
+// Print a one-line message when given, otherwise the full help screen.
+void usage(const char *msg = 0)
+{
+ if (msg) {
+ std::cerr << msg << std::endl;
+ return;
+ }
+ print_help();
+}
+
+// Entry point of the `dict` tool: parse command-line options (cmd.h),
+// build a dictionary from a corpus or a saved dictionary file, optionally
+// sort/prune it, report growth/OOV statistics, and save the result.
+int main(int argc, char **argv)
+{
+ char *inp=NULL;
+ char *out=NULL;
+ char *testfile=NULL;
+ char *intsymb=NULL; //must be single characters
+ int freqflag=0; //print frequency of words
+ int sortflag=0; //sort dictionary by frequency
+ int curveflag=0; //plot dictionary growth curve
+ int curvesize=10; //size of curve
+ int listflag=0; //print oov words in test file
+ int size=1000000; //initial size of table ....
+ float load_factor=0; //initial load factor, default LOAD_FACTOR
+
+ int prunefreq=0; //pruning according to freq value
+ int prunerank=0; //pruning according to freq rank
+
+ bool help=false;
+
+ // Each option is declared twice: long name and short alias bound to the
+ // same variable. The list is NULL-terminated.
+ DeclareParams((char*)
+ "InputFile", CMDSTRINGTYPE|CMDMSG, &inp, "input file (Mandatory)",
+ "i", CMDSTRINGTYPE|CMDMSG, &inp, "input file (Mandatory)",
+ "OutputFile", CMDSTRINGTYPE|CMDMSG, &out, "output file",
+ "o", CMDSTRINGTYPE|CMDMSG, &out, "output file",
+ "f", CMDBOOLTYPE|CMDMSG, &freqflag,"output word frequencies; default is false",
+ "Freq", CMDBOOLTYPE|CMDMSG, &freqflag,"output word frequencies; default is false",
+ "sort", CMDBOOLTYPE|CMDMSG, &sortflag,"sort dictionary by frequency; default is false",
+ "Size", CMDINTTYPE|CMDMSG, &size, "Initial dictionary size; default is 1000000",
+ "s", CMDINTTYPE|CMDMSG, &size, "Initial dictionary size; default is 1000000",
+ "LoadFactor", CMDFLOATTYPE|CMDMSG, &load_factor, "set the load factor for cache; it should be a positive real value; default is 0",
+ "lf", CMDFLOATTYPE|CMDMSG, &load_factor, "set the load factor for cache; it should be a positive real value; default is 0",
+ "IntSymb", CMDSTRINGTYPE|CMDMSG, &intsymb, "interruption symbol",
+ "is", CMDSTRINGTYPE|CMDMSG, &intsymb, "interruption symbol",
+
+ "PruneFreq", CMDINTTYPE|CMDMSG, &prunefreq, "prune words with frequency below the specified value",
+ "pf", CMDINTTYPE|CMDMSG, &prunefreq, "prune words with frequency below the specified value",
+ "PruneRank", CMDINTTYPE|CMDMSG, &prunerank, "prune words with frequency rank above the specified value",
+ "pr", CMDINTTYPE|CMDMSG, &prunerank, "prune words with frequency rank above the specified value",
+
+ "Curve", CMDBOOLTYPE|CMDMSG, &curveflag,"show dictionary growth curve; default is false",
+ "c", CMDBOOLTYPE|CMDMSG, &curveflag,"show dictionary growth curve; default is false",
+ "CurveSize", CMDINTTYPE|CMDMSG, &curvesize, "default 10",
+ "cs", CMDINTTYPE|CMDMSG, &curvesize, "default 10",
+
+ "TestFile", CMDSTRINGTYPE|CMDMSG, &testfile, "compute OOV rates on the specified test corpus",
+ "t", CMDSTRINGTYPE|CMDMSG, &testfile, "compute OOV rates on the specified test corpus",
+ "ListOOV", CMDBOOLTYPE|CMDMSG, &listflag, "print OOV words to stderr; default is false",
+ "oov", CMDBOOLTYPE|CMDMSG, &listflag, "print OOV words to stderr; default is false",
+
+ "Help", CMDBOOLTYPE|CMDMSG, &help, "print this help",
+ "h", CMDBOOLTYPE|CMDMSG, &help, "print this help",
+
+ (char*)NULL
+ );
+
+ // no arguments at all: show help and exit cleanly
+ if (argc == 1){
+ usage();
+ exit_error(IRSTLM_NO_ERROR);
+ }
+
+ GetParams(&argc, &argv, (char*) NULL);
+
+ if (help){
+ usage();
+ exit_error(IRSTLM_NO_ERROR);
+ }
+
+ if (inp==NULL) {
+ usage();
+ exit_error(IRSTLM_NO_ERROR, "Warning: no input file specified");
+ };
+
+ // options compatibility issues:
+ // growth curves and OOV statistics both need frequencies, so force
+ // freqflag on; the test file is probed here so a bad path fails early.
+ if (curveflag && !freqflag)
+ freqflag=1;
+ if (testfile!=NULL && !freqflag) {
+ freqflag=1;
+ mfstream test(testfile,ios::in);
+ if (!test) {
+ usage();
+ std::string msg("Warning: cannot open testfile: ");
+ msg.append(testfile);
+ exit_error(IRSTLM_NO_ERROR, msg);
+ }
+ test.close();
+
+ }
+
+ //create dictionary: generating it from training corpus, or loading it from a dictionary file
+ dictionary *d = new dictionary(inp,size,load_factor);
+
+ // sort dictionary
+ // pruning requires a frequency-sorted copy, so sorting happens for
+ // prune options too, not only for -sort
+ if (prunefreq>0 || prunerank>0 || sortflag) {
+ dictionary *sortd=new dictionary(d,false);
+ sortd->sort();
+ delete d;
+ d=sortd;
+ }
+
+
+ // show statistics on dictionary growth and OOV rates on test corpus
+ if (testfile != NULL)
+ d->print_curve_oov(curvesize, testfile, listflag);
+ if (curveflag)
+ d->print_curve_growth(curvesize);
+
+
+ //prune words according to frequency and rank
+ // pruning means zeroing the frequency; save() skips freq==0 entries.
+ // NOTE(review): the help text says "frequency below the specified
+ // value" but the test is <= prunefreq (inclusive) — confirm intended.
+ if (prunefreq>0 || prunerank>0) {
+ cerr << "pruning dictionary prunefreq:" << prunefreq << " prunerank: " << prunerank <<" \n";
+ int count=0;
+ int bos=d->encode(d->BoS());
+ int eos=d->encode(d->EoS());
+
+ for (int i=0; i< d->size() ; i++) {
+ if (prunefreq && d->freq(i) <= prunefreq && i!=bos && i!=eos) {
+ d->freq(i,0);
+ continue;
+ }
+ // table is frequency-sorted here, so rank == number of kept words
+ if (prunerank>0 && count>=prunerank && i!=bos && i!=eos) {
+ d->freq(i,0);
+ continue;
+ }
+ count++;
+ }
+ }
+ // if outputfile is provided, write the dictionary into it
+ if(out!=NULL) d->save(out,freqflag);
+
+}
+
diff --git a/src/dictionary.cpp b/src/dictionary.cpp
new file mode 100644
index 0000000..95e727e
--- /dev/null
+++ b/src/dictionary.cpp
@@ -0,0 +1,577 @@
+// $Id: dictionary.cpp 3640 2010-10-08 14:58:17Z bertoldi $
+
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <iomanip>
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include "mempool.h"
+#include "htable.h"
+#include "index.h"
+#include "util.h"
+#include "dictionary.h"
+#include "mfstream.h"
+
+using namespace std;
+
+// Construct a dictionary with initial capacity `size` and hash load
+// factor lf (non-positive lf falls back to DICTIONARY_LOAD_FACTOR).
+// When filename is non-NULL, the first token of the file is sniffed:
+// a "dict"/"DICT" prefix means a previously saved dictionary (load),
+// anything else is treated as a raw corpus (generate).
+dictionary::dictionary(char *filename,int size, float lf)
+{
+ if (lf<=0.0) lf=DICTIONARY_LOAD_FACTOR;
+ load_factor=lf;
+
+ htb = new HASHTABLE_t((size_t) (size/load_factor));
+ tb = new dict_entry[size];
+ st = new strstack(size * 10);
+
+ for (int i=0; i<size; i++) tb[i].freq=0;
+
+ oov_code = -1;
+
+ n = 0;
+ N = 0;
+ dubv = 0;
+ lim = size;
+ ifl=0; //increment flag
+
+ if (filename==NULL) return;
+
+ mfstream inp(filename,ios::in);
+
+ if (!inp) {
+ std::stringstream ss_msg;
+ ss_msg << "cannot open " << filename << "\n";
+ exit_error(IRSTLM_ERROR_IO, ss_msg.str());
+ }
+
+ // read the first whitespace-delimited token only, then reopen the file
+ // in load()/generate()
+ char buffer[100];
+
+ inp >> setw(100) >> buffer;
+
+ inp.close();
+
+ if ((strncmp(buffer,"dict",4)==0) ||
+ (strncmp(buffer,"DICT",4)==0))
+ load(filename);
+ else
+ generate(filename);
+
+ cerr << "loaded \n";
+
+}
+
+
+// Read the next whitespace-delimited token (at most MAX_WORD-1 chars)
+// into buffer. Returns 1 on success, 0 at end of input. Empty tokens are
+// skipped with a warning; tokens hitting the length cap are reported
+// because they have been truncated by setw.
+int dictionary::getword(fstream& inp , char* buffer) const
+{
+ for (;;) {
+ if (!(inp >> setw(MAX_WORD) >> buffer)) return 0;
+
+ const size_t len = strlen(buffer);
+
+ if (len==(MAX_WORD-1))
+ cerr << "getword: a very long word was read ("
+ << buffer << ")\n";
+
+ if (len==0) {
+ cerr << "zero length word!\n";
+ continue;
+ }
+
+ return 1;
+ }
+}
+
+
+// Build the dictionary from a raw corpus: every token is encoded (the
+// table grows as needed, ifl=1) and its frequency incremented. With
+// header=true the first line is skipped. Progress dot every 1M tokens.
+void dictionary::generate(char *filename,bool header)
+{
+
+ char buffer[MAX_WORD];
+ int counter=0;
+
+ mfstream inp(filename,ios::in);
+
+ if (!inp) {
+ std::stringstream ss_msg;
+ ss_msg << "cannot open " << filename << "\n";
+ exit_error(IRSTLM_ERROR_IO, ss_msg.str());
+ }
+
+ cerr << "dict:";
+
+ ifl=1;
+
+ //skip header
+ if (header) inp.getline(buffer,MAX_WORD);
+
+ while (getword(inp,buffer)) {
+
+ incfreq(encode(buffer),1);
+
+ if (!(++counter % 1000000)) cerr << ".";
+ }
+
+ ifl=0;
+
+ cerr << "\n";
+
+ inp.close();
+
+}
+
+// Merge every word of dictionary d into this one (frequencies are not
+// transferred), then make sure the OOV entry exists.
+void dictionary::augment(dictionary *d)
+{
+ incflag(1); // temporarily allow the table to grow
+ const int total=d->n;
+ for (int k=0; k<total; ++k) encode(d->decode(k));
+ incflag(0);
+ encode(OOV()); // with ifl==0 this creates/returns the OOV code
+}
+
+
+// print_curve: show statistics on dictionary growth
+void dictionary::print_curve_growth(int curvesize) const
+{
+ int* curve = new int[curvesize];
+ for (int i=0; i<curvesize; i++) curve[i]=0;
+
+ // filling the curve
+ for (int i=0; i<n; i++) {
+ if(tb[i].freq > curvesize-1)
+ curve[curvesize-1]++;
+ else
+ curve[tb[i].freq-1]++;
+ }
+
+ //cumulating results
+ for (int i=curvesize-2; i>=0; i--) {
+ curve[i] = curve[i] + curve[i+1];
+ }
+
+ cout.setf(ios::fixed);
+ cout << "Dict size: " << n << "\n";
+ cout << "**************** DICTIONARY GROWTH CURVE ****************\n";
+ cout << "Freq\tEntries\tPercent";
+ cout << "\n";
+
+ for (int i=0; i<curvesize; i++) {
+ cout << ">" << i << "\t" << curve[i] << "\t" << setprecision(2) << (float)curve[i]/n * 100.0 << "%";
+ cout << "\n";
+ }
+ cout << "*********************************************************\n";
+ delete []curve;
+}
+
+// print_curve_oov: show OOV amount and OOV rates computed on test corpus
+void dictionary::print_curve_oov(int curvesize, const char *filename, int listflag)
+{
+ int *OOVchart=new int[curvesize];
+ int NwTest;
+
+ test(OOVchart, &NwTest, curvesize, filename, listflag);
+
+ cout.setf(ios::fixed);
+ cout << "Dict size: " << n << "\n";
+ cout << "Words of test: " << NwTest << "\n";
+ cout << "**************** OOV RATE STATISTICS ****************\n";
+ cout << "Freq\tOOV_Entries\tOOV_Rate";
+ cout << "\n";
+
+ for (int i=0; i<curvesize; i++) {
+
+ // display OOV iamount and OOV rates on test
+ cout << "<" << i+1 << "\t" << OOVchart[i] << "\t" << setprecision(2) << (float)OOVchart[i]/NwTest * 100.0 << "%";
+ cout << "\n";
+ }
+ cout << "*********************************************************\n";
+ delete []OOVchart;
+}
+
+//
+// test : compute OOV rates on test corpus using dictionaries of different sizes
+//
+// Scan a test corpus and fill OOVchart so that, after the final
+// cumulation pass, OOVchart[i] holds the number of test tokens whose
+// dictionary frequency is <= i (unknown words count as frequency 0).
+// NwTest receives the number of scored tokens; <s> tokens are skipped.
+// With listflag, unknown words are echoed to stderr wrapped in <OOV> tags.
+void dictionary::test(int* OOVchart, int* NwTest, int curvesize, const char *filename, int listflag)
+{
+ MY_ASSERT(OOVchart!=NULL);
+
+ int m_NwTest=0;
+ for (int j=0; j<curvesize; j++) OOVchart[j]=0;
+ char buffer[MAX_WORD];
+
+ const char* bos = BoS();
+
+ mfstream inp(filename,ios::in);
+
+ if (!inp) {
+ std::stringstream ss_msg;
+ ss_msg << "cannot open " << filename << "\n";
+ exit_error(IRSTLM_ERROR_IO, ss_msg.str());
+ }
+ cerr << "test:";
+
+ int k = 0;
+ while (getword(inp,buffer)) {
+
+ // skip 'beginning of sentence' symbol
+ if (strcmp(buffer,bos)==0)
+ continue;
+
+ // NOTE(review): tb[].freq is long long but is narrowed to int here —
+ // confirm counts stay within int range for the corpora used.
+ int freq = 0;
+ int wCode = getcode(buffer);
+ if(wCode!=-1) freq = tb[wCode].freq;
+
+ if(freq==0) {
+ OOVchart[0]++;
+ if(listflag) {
+ cerr << "<OOV>" << buffer << "</OOV>\n";
+ }
+ } else {
+ if(freq < curvesize) OOVchart[freq]++;
+ }
+ m_NwTest++;
+ if (!(++k % 1000000)) cerr << ".";
+ }
+ cerr << "\n";
+ inp.close();
+
+ // cumulating results
+ for (int i=1; i<curvesize; i++){
+ OOVchart[i] = OOVchart[i] + OOVchart[i-1];
+ }
+ *NwTest=m_NwTest;
+}
+
+// Load a dictionary saved by save(char*,int). A "DICT" header means each
+// word is followed by its frequency; "dict" means bare words. Any other
+// header is a fatal data error. Duplicate words are warned about and
+// skipped; the table grows on demand.
+void dictionary::load(char* filename)
+{
+ char header[100];
+ char* addr;
+ char buffer[MAX_WORD];
+ int freqflag=0;
+
+ mfstream inp(filename,ios::in);
+
+ if (!inp) {
+ std::stringstream ss_msg;
+ ss_msg << "cannot open " << filename << "\n";
+ exit_error(IRSTLM_ERROR_IO, ss_msg.str());
+ }
+
+ cerr << "dict:";
+
+ inp.getline(header,100);
+ if (strncmp(header,"DICT",4)==0)
+ freqflag=1;
+ else if (strncmp(header,"dict",4)!=0) {
+ std::stringstream ss_msg;
+ ss_msg << "dictionary file " << filename << " has a wrong header";
+ exit_error(IRSTLM_ERROR_DATA, ss_msg.str());
+ }
+
+
+ while (getword(inp,buffer)) {
+
+ tb[n].word=st->push(buffer);
+ tb[n].code=n;
+
+ if (freqflag)
+ inp >> tb[n].freq;
+ else
+ tb[n].freq=0;
+
+ //always insert without checking whether the word is already in
+ // insert() returns the address already stored when the key exists;
+ // a different address therefore means a duplicate entry in the file.
+ // NOTE(review): on that path the string pushed to st above is leaked.
+ if ((addr=htb->insert((char*)&tb[n].word))) {
+ if (addr!=(char *)&tb[n].word) {
+ cerr << "dictionary::loadtxt wrong entry was found ("
+ << buffer << ") in position " << n << "\n";
+ // exit(1);
+ continue; // continue loading dictionary
+ }
+ }
+
+ N+=tb[n].freq;
+
+ if (strcmp(buffer,OOV())==0) oov_code=n;
+
+ if (++n==lim) grow();
+
+ }
+
+ inp.close();
+}
+
+
+// Load dictionary entries from an open stream in the save(ostream&)
+// format: a count line, then "word freq" pairs. Unlike load(char*),
+// a duplicate entry here is a fatal data error. The trailing getline
+// consumes the rest of the last line so the caller can keep reading.
+void dictionary::load(std::istream& inp)
+{
+
+ char buffer[MAX_WORD];
+ char *addr;
+ int size;
+
+ inp >> size;
+
+ for (int i=0; i<size; i++) {
+
+ inp >> setw(MAX_WORD) >> buffer;
+
+ tb[n].word=st->push(buffer);
+ tb[n].code=n;
+ inp >> tb[n].freq;
+ N+=tb[n].freq;
+
+ //always insert without checking whether the word is already in
+ if ((addr=htb->insert((char *)&tb[n].word))) {
+ if (addr!=(char *)&tb[n].word) {
+ std::stringstream ss_msg;
+ ss_msg << "dictionary::loadtxt wrong entry was found (" << buffer << ") in position " << n;
+ exit_error(IRSTLM_ERROR_DATA, ss_msg.str());
+ }
+ }
+
+ if (strcmp(tb[n].word,OOV())==0)
+ oov_code=n;
+
+ if (++n==lim) grow();
+ }
+ inp.getline(buffer,MAX_WORD-1);
+}
+
+
+// Serialize to an open stream: entry count on the first line, then one
+// "word freq" pair per line (the format expected by load(std::istream&)).
+void dictionary::save(std::ostream& out)
+{
+ out << n << "\n";
+ for (int k=0; k<n; ++k)
+ out << tb[k].word << " " << tb[k].freq << "\n";
+}
+
+// qsort comparator for dict_entry: descending frequency, ties broken by
+// ascending word order (strcmp).
+// Fix: the original returned be->freq-ae->freq, a long long difference
+// implicitly truncated to int, which can overflow and flip sign for large
+// counts or large gaps; compare explicitly instead, as required by the
+// qsort comparator contract.
+int cmpdictentry(const void *a,const void *b)
+{
+ dict_entry *ae=(dict_entry *)a;
+ dict_entry *be=(dict_entry *)b;
+
+ if (be->freq > ae->freq) return 1;
+ if (be->freq < ae->freq) return -1;
+ return strcmp(ae->word,be->word);
+}
+
+
+// Copy constructor with optional pruning: clones d, keeping only entries
+// with freq > prunethresh when prune is true. Codes are re-assigned
+// densely and the OOV code is remapped to its new position (or stays -1
+// if the OOV entry was pruned away).
+dictionary::dictionary(dictionary* d,bool prune, int prunethresh)
+{
+ MY_ASSERT(d!=NULL);
+ //transfer values
+ n=0; //total entries
+ N=0; //total frequency
+
+ load_factor=d->load_factor; //load factor
+ lim=d->lim; //limit of entries
+ oov_code=-1; //code od oov must be re-defined
+ ifl=0; //increment flag=0;
+ dubv=d->dubv; //dictionary upperbound transferred
+
+ //creates a sorted copy of the table
+ tb = new dict_entry[lim];
+ htb = new HASHTABLE_t((size_t) (lim/load_factor));
+ st = new strstack(lim * 10);
+
+ //copy in the entries with frequency > threshold
+ n=0;
+ for (int i=0; i<d->n; i++)
+ if (!prune || d->tb[i].freq>prunethresh){
+ tb[n].code=n;
+ tb[n].freq=d->tb[i].freq;
+ tb[n].word=st->push(d->tb[i].word);
+ // the hash key is the address of the word pointer inside tb
+ htb->insert((char*)&tb[n].word);
+
+ if (d->oov_code==i) oov_code=n; //reassign oov_code
+
+ N+=tb[n].freq;
+ n++;
+ }
+};
+
+// Sort entries by descending frequency (ties by word, see cmpdictentry)
+// and re-assign codes to match the new order. The hash table stores
+// addresses of the word slots inside tb, which qsort moves around, so it
+// must be rebuilt from scratch after sorting.
+void dictionary::sort()
+{
+ if (htb != NULL ) delete htb;
+
+ htb = new HASHTABLE_t((int) (lim/load_factor));
+ //sort all entries according to frequency
+ cerr << "sorting dictionary ...";
+ qsort(tb,n,sizeof(dict_entry),cmpdictentry);
+ cerr << "done\n";
+
+ for (int i=0; i<n; i++) {
+ //eventually re-assign oov code
+ if (oov_code==tb[i].code) oov_code=i;
+ tb[i].code=i;
+ //always insert without checking whether the word is already in
+ htb->insert((char*)&tb[i].word);
+ };
+
+}
+
+// Release the hash table, the interned-string stack and the entry table.
+dictionary::~dictionary()
+{
+ delete htb;
+ delete st;
+ delete [] tb;
+}
+
+// Print the entry count and a rough estimate of the memory in use (Kb):
+// the entry table plus the hash table and string stack footprints.
+void dictionary::stat() const
+{
+ size_t bytes = lim * sizeof(int) + htb->used() + st->used();
+ cout << "dictionary class statistics\n";
+ cout << "size " << n << " used memory " << bytes/1024 << " Kb\n";
+}
+
+// Enlarge the entry table by GROWTH_STEP. The hash table keys are
+// addresses of the word slots inside tb, so it is deleted *before* the
+// table is reallocated and rebuilt afterwards over the moved entries.
+void dictionary::grow()
+{
+ delete htb;
+
+ cerr << "+\b";
+
+ int newlim=(int) (lim*GROWTH_STEP);
+ dict_entry *tb2=new dict_entry[newlim];
+
+ memcpy(tb2,tb,sizeof(dict_entry) * lim );
+
+ delete [] tb;
+ tb=tb2;
+
+ htb=new HASHTABLE_t((size_t) ((newlim)/load_factor));
+ for (int i=0; i<lim; i++) {
+ //always insert without checking whether the word is already in
+ htb->insert((char*)&tb[i].word);
+ }
+
+ // zero the frequency of the freshly added slots
+ for (int i=lim; i<newlim; i++) tb[i].freq=0;
+
+ lim=newlim;
+}
+
+// Write the dictionary to filename. With freqflag!=0 a "DICTIONARY"
+// header is written and each word is followed by its frequency; otherwise
+// a lowercase "dictionary" header and bare words. Entries with freq==0
+// (pruned words) are skipped.
+// Fix: on open failure the original only printed a warning and then kept
+// writing to the dead stream; now it returns early.
+void dictionary::save(char *filename,int freqflag)
+{
+
+ std::ofstream out(filename,ios::out);
+
+ if (!out) {
+ cerr << "cannot open " << filename << "\n";
+ return; // nothing can be written; do not pretend the save succeeded
+ }
+
+ // header
+ if (freqflag)
+ out << "DICTIONARY 0 " << n << "\n";
+ else
+ out << "dictionary 0 " << n << "\n";
+
+ for (int i=0; i<n; i++)
+ if (tb[i].freq) { //do not print pruned words!
+ out << tb[i].word;
+ if (freqflag)
+ out << " " << tb[i].freq;
+ out << "\n";
+ }
+
+ out.close();
+}
+
+
+// Look up word w and return its code, or -1 when it is not in the
+// dictionary. Never extends the table (unlike encode()).
+int dictionary::getcode(const char *w)
+{
+ // the hash is keyed on the address of the word pointer
+ dict_entry* entry=(dict_entry *)htb->find((char *)&w);
+ return (entry==NULL) ? -1 : entry->code;
+}
+
+// Map word w to its integer code. Behavior depends on the increment flag:
+// with ifl set, unknown words are appended to the table (growing it when
+// full); with ifl clear, unknown words map to the OOV code, creating the
+// OOV entry on first use. Empty strings are treated as OOV.
+int dictionary::encode(const char *w)
+{
+ //case of strange characters
+ if (strlen(w)==0) {
+ cerr << "0";
+ w=OOV();
+ }
+
+
+ dict_entry* ptr;
+
+ // lookup is keyed on the address of the word pointer
+ if ((ptr=(dict_entry *)htb->find((char *)&w))!=NULL)
+ return ptr->code;
+ else {
+ if (!ifl) { //do not extend dictionary
+ if (oov_code==-1) { //did not use OOV yet
+ cerr << "starting to use OOV words [" << w << "]\n";
+ tb[n].word=st->push(OOV());
+ htb->insert((char *)&tb[n].word);
+ tb[n].code=n;
+ tb[n].freq=0;
+ oov_code=n;
+ if (++n==lim) grow();
+ }
+
+ // re-resolve through the hash so the OOV entry's code is returned
+ return encode(OOV());
+ } else { //extend dictionary
+ tb[n].word=st->push((char *)w);
+ htb->insert((char*)&tb[n].word);
+ tb[n].code=n;
+ tb[n].freq=0;
+ if (++n==lim) grow();
+ return n-1;
+ }
+ }
+}
+
+
+// Map a code back to its word. Out-of-range codes fall back to the OOV
+// string after a warning on stderr.
+const char *dictionary::decode(int c) const
+{
+ if (c<0 || c>=n) {
+ cerr << "decode: code out of boundary\n";
+ return OOV();
+ }
+ return tb[c].word;
+}
+
+
+// Bind the iterator to a dictionary and start a scan of its hash table.
+dictionary_iter::dictionary_iter(dictionary *dict) : m_dict(dict)
+{
+ m_dict->scan(HT_INIT);
+}
+
+// Return the next entry of the scan, or NULL when exhausted.
+dict_entry* dictionary_iter::next()
+{
+ return (dict_entry*) m_dict->scan(HT_CONT);
+}
+
+/*
+ main(int argc,char **argv){
+ dictionary d(argv[1],40000);
+ d.stat();
+ cout << "ROMA" << d.decode(0) << "\n";
+ cout << "ROMA:" << d.encode("ROMA") << "\n";
+ d.save(argv[2]);
+ }
+ */
diff --git a/src/dictionary.h b/src/dictionary.h
new file mode 100644
index 0000000..9423c8f
--- /dev/null
+++ b/src/dictionary.h
@@ -0,0 +1,249 @@
+// $Id: dictionary.h 3679 2010-10-13 09:10:01Z bertoldi $
+
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+
+#ifndef MF_DICTIONARY_H
+#define MF_DICTIONARY_H
+
+#include "mfstream.h"
+#include "htable.h"
+#include <cstring>
+#include <iostream>
+
+
+using namespace std;
+
+#define MAX_WORD 1000
+#define DICTIONARY_LOAD_FACTOR 2.0
+
+
+#ifndef GROWTH_STEP
+#define GROWTH_STEP 1.5
+#endif
+
+#ifndef DICT_INITSIZE
+#define DICT_INITSIZE 100000
+#endif
+
+//Begin of sentence symbol
+#ifndef BOS_
+#define BOS_ "<s>"
+#endif
+
+
+//End of sentence symbol
+#ifndef EOS_
+#define EOS_ "</s>"
+#endif
+
+//End of document symbol
+#ifndef BOD_
+#define BOD_ "<d>"
+#endif
+
+//End of document symbol
+#ifndef EOD_
+#define EOD_ "</d>"
+#endif
+
+
+//Out-Of-Vocabulary symbol
+#ifndef OOV_
+#define OOV_ "<unk>"
+#endif
+
+// One dictionary slot: interned word string, its integer code and its
+// corpus frequency.
+typedef struct {
+ const char *word; // interned string, owned by the dictionary's strstack
+ int code; // integer code (index within the entry table)
+ long long freq; // corpus frequency
+} dict_entry;
+
+typedef htable<char*> HASHTABLE_t;
+
+class strstack;
+
+// Vocabulary with word<->code mapping and frequency bookkeeping, backed
+// by an entry table (tb), a hash table keyed on the addresses of the word
+// slots, and a string stack interning the word strings.
+class dictionary
+{
+ strstack *st; //!< stack of strings
+ dict_entry *tb; //!< entry table
+ HASHTABLE_t *htb; //!< hash table
+ int n; //!< number of entries
+ long long N; //!< total frequency
+ int lim; //!< limit of entries
+ int oov_code; //!< code assigned to oov words
+ char ifl; //!< increment flag
+ int dubv; //!< dictionary size upper bound
+ float load_factor; //!< dictionary loading factor
+ char* oov_str; //!< oov string
+
+ void test(int* OOVchart, int* NwTest, int curvesize, const char *filename, int listflag=0); // prepare into testOOV the OOV statistics computed on test set
+
+public:
+
+ friend class dictionary_iter;
+
+ dictionary* oovlex; //<! additional dictionary
+
+ // dictionary upper bound: getter and setter
+ inline int dub() const {
+ return dubv;
+ }
+
+ inline int dub(int value) {
+ return dubv=value;
+ }
+
+ // reserved symbols (OOV, sentence and document delimiters)
+ inline static const char *OOV() {
+ return (char*) OOV_;
+ }
+
+ inline static const char *BoS() {
+ return (char*) BOS_;
+ }
+
+ inline static const char *EoS() {
+ return (char*) EOS_;
+ }
+
+ inline static const char *BoD() {
+ return (char*) BOD_;
+ }
+
+ inline static const char *EoD() {
+ return (char*) EOD_;
+ }
+
+ // get the OOV code; a non-negative argument also sets it
+ inline int oovcode(int v=-1) {
+ return oov_code=(v>=0?v:oov_code);
+ }
+
+ // increment flag: when set, encode() may extend the dictionary
+ inline int incflag() const {
+ return ifl;
+ }
+
+ inline int incflag(int v) {
+ return ifl=v;
+ }
+
+ int getword(fstream& inp , char* buffer) const;
+
+ // NOTE(review): sprintf copies w into buffer verbatim, so the strcmp
+ // always succeeds for any w that fits — as written this returns true
+ // unconditionally (and overflows buffer for w longer than MAX_WORD);
+ // confirm the intended check.
+ int isprintable(char* w) const {
+ char buffer[MAX_WORD];
+ sprintf(buffer,"%s",w);
+ return strcmp(w,buffer)==0;
+ }
+
+ inline void genoovcode() {
+ int c=encode(OOV());
+ std::cerr << "OOV code is "<< c << std::endl;
+ // NOTE(review): the line below repeats the message above, so "OOV
+ // code is" is printed twice — confirm whether one was meant to go.
+ cerr << "OOV code is "<< c << std::endl;
+ oovcode(c);
+ }
+
+ inline void genBoScode() {
+ int c=encode(BoS());
+ std::cerr << "BoS code is "<< c << std::endl;
+ }
+
+ inline void genEoScode() {
+ int c=encode(EoS());
+ std::cerr << "EoS code is "<< c << std::endl;
+ }
+
+ // force the OOV frequency to oovrate * total frequency
+ inline long long setoovrate(double oovrate) {
+ encode(OOV()); //be sure OOV code exists
+ long long oovfreq=(long long)(oovrate * totfreq());
+ std::cerr << "setting OOV rate to: " << oovrate << " -- freq= " << oovfreq << std::endl;
+ return freq(oovcode(),oovfreq);
+ }
+
+ // add value to the frequency of `code` (and to the grand total N)
+ inline long long incfreq(int code,long long value) {
+ N+=value;
+ return tb[code].freq+=value;
+ }
+
+ // scale the frequency of `code` by value, keeping N consistent
+ inline long long multfreq(int code,double value) {
+ N+=(long long)(value * tb[code].freq)-tb[code].freq;
+ return tb[code].freq=(long long)(value * tb[code].freq);
+ }
+
+ // get the frequency of `code`; a non-negative value also sets it
+ inline long long freq(int code,long long value=-1) {
+ if (value>=0) {
+ N+=value-tb[code].freq;
+ tb[code].freq=value;
+ }
+ return tb[code].freq;
+ }
+
+ inline long long totfreq() const {
+ return N;
+ }
+
+ inline float set_load_factor(float value) {
+ return load_factor=value;
+ }
+
+ void grow();
+ void sort();
+
+ dictionary(char *filename,int size=DICT_INITSIZE,float lf=DICTIONARY_LOAD_FACTOR);
+ dictionary(dictionary* d, bool prune=false,int prunethresh=0); //make a copy and eventually filter out unfrequent words
+
+ ~dictionary();
+ void generate(char *filename,bool header=false);
+ void load(char *filename);
+ void save(char *filename, int freqflag=0);
+ void load(std::istream& fd);
+ void save(std::ostream& fd);
+
+ void augment(dictionary *d);
+
+ int size() const {
+ return n;
+ }
+ int getcode(const char *w);
+ int encode(const char *w);
+
+ const char *decode(int c) const;
+ void stat() const;
+
+ void print_curve_growth(int curvesize) const;
+ void print_curve_oov(int curvesize, const char *filename, int listflag=0);
+
+ // zero all frequencies and the grand total (table entries are kept)
+ void cleanfreq() {
+ for (int i=0; i<n; ++i){ tb[i].freq=0; };
+ N=0;
+ }
+
+ inline dict_entry* scan(HT_ACTION action) {
+ return (dict_entry*) htb->scan(action);
+ }
+};
+
+// Simple forward iterator over all entries of a dictionary, driven by the
+// underlying hash-table scan (see dictionary::scan).
+class dictionary_iter
+{
+public:
+ dictionary_iter(dictionary *dict);
+ dict_entry* next(); // next entry, or NULL when the scan is done
+private:
+ dictionary* m_dict;
+};
+#endif
+
diff --git a/src/doc.cpp b/src/doc.cpp
new file mode 100755
index 0000000..ec59756
--- /dev/null
+++ b/src/doc.cpp
@@ -0,0 +1,111 @@
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit, compile LM
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+#include <math.h>
+#include "util.h"
+#include "mfstream.h"
+#include "mempool.h"
+#include "htable.h"
+#include "dictionary.h"
+#include "n_gram.h"
+#include "doc.h"
+
+using namespace std;
+
+// Load a document collection: the file starts with a line carrying the
+// number of documents; documents are streams of words delimited by
+// <d> ... </d>. With use_null_word, <d> is kept as a NULL word at the
+// start of each document, otherwise it is skipped.
+// Fixes: (1) over-long documents are now really clipped at MAXDOCLEN —
+// the original warned "clipping" but then called exit(1); (2) after
+// loading, N is reconciled with the number of documents actually read so
+// doclen/docword and the destructor never touch uninitialized M/V slots.
+doc::doc(dictionary* d,char* docfname,bool use_null_word){
+ mfstream df(docfname,ios::in);
+
+ char header[100];
+ df.getline(header,100);
+ sscanf(header,"%d",&N);
+
+ assert(N>0 && N < MAXDOCNUM);
+
+
+ M=new int [N];
+ V=new int* [N];
+
+ int eod=d->encode(d->EoD());
+ int bod=d->encode(d->BoD());
+
+
+ ngram ng(d);
+ int n=0; //track documents
+ int m=0; //track document length
+ int w=0; //track words in doc
+ bool clipped=false; //already warned about the current document?
+
+ int tmp[MAXDOCLEN];
+
+ while (n<N && df >> ng)
+ if (ng.size>0){
+ w=*ng.wordp(1);
+ if (w==bod){
+ if (use_null_word){
+ ng.size=1; //use <d> as NULL word
+ }else{
+ ng.size=0; //skip <d>
+ continue;
+ }
+ }
+ if (w==eod && m>0){
+ M[n]=m; //length of n-th document
+ V[n]=new int[m];
+ memcpy(V[n],tmp,m * sizeof(int));
+ m=0;
+ clipped=false;
+ n++;
+ continue;
+ }
+
+ if (m < MAXDOCLEN) tmp[m++]=w;
+ else if (!clipped){
+ //drop words beyond the limit, warning once per document
+ cerr<< "warn: clipping long document (line " << n << " )\n";
+ clipped=true;
+ }
+ }
+
+ //reconcile the declared count with what was actually read, so that
+ //doclen/docword and the destructor only see initialized slots
+ if (n<N) N=n;
+
+ cerr << "uploaded " << n << " documents\n";
+
+
+};
+
+// Free the per-document word arrays, then the bookkeeping tables.
+doc::~doc(){
+ cerr << "releasing document storage\n";
+ for (int k=N; k-- > 0; ) delete [] V[k];
+ delete [] M;
+ delete [] V;
+}
+
+
+// Number of documents in the collection.
+int doc::numdoc(){
+ return this->N;
+}
+
+// Length (in words) of the document at `idx`; asserts the index is valid.
+int doc::doclen( int idx){
+ assert(0<=idx && idx<N);
+ return M[idx];
+}
+
+// Word code at position `pos` of document `di`; asserts the position is
+// within the document (which in turn asserts the document index).
+int doc::docword( int di, int pos){
+ assert(0<=pos && pos<doclen(di));
+ return V[di][pos];
+}
+
+
+
+
+
+
+
+
diff --git a/src/doc.h b/src/doc.h
new file mode 100755
index 0000000..b187ccf
--- /dev/null
+++ b/src/doc.h
@@ -0,0 +1,43 @@
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit, compile LM
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+//class managing a collection of documents for PLSA
+
+
+#define MAXDOCLEN 500
+#define MAXDOCNUM 1000000000
+
+// In-memory collection of documents (as arrays of word codes) for PLSA
+// training; loaded from a file whose first line holds the document count.
+class doc
+{
+ int N; //number of docs
+ int *M; //number of words per document
+ int **V; //words in current doc
+
+public:
+
+ doc(dictionary* d,char* docfname,bool use_null_word=false);
+ ~doc();
+
+ int numdoc(); // number of documents
+ int doclen(int index); // length of one document
+ int docword(int docindex, int wordindex); // word code at a position
+
+};
+
+
diff --git a/src/dtsel.cpp b/src/dtsel.cpp
new file mode 100644
index 0000000..79caab5
--- /dev/null
+++ b/src/dtsel.cpp
@@ -0,0 +1,397 @@
+// $Id: ngt.cpp 245 2009-04-02 14:05:40Z fabio_brugnara $
+
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+
+
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+
+// dtsel
+// by M. Federico
+// Copyright Marcello Federico, Fondazione Bruno Kessler, 2012
+
+
+#include <cmath>
+#include "util.h"
+#include <sstream>
+#include "mfstream.h"
+#include "mempool.h"
+#include "htable.h"
+#include "dictionary.h"
+#include "n_gram.h"
+#include "ngramtable.h"
+#include "cmd.h"
+
+using namespace std;
+
+#define YES 1
+#define NO 0
+
+void print_help(int TypeFlag=0){
+ std::cerr << std::endl << "dtsel - performs data selection" << std::endl;
+ std::cerr << std::endl << "USAGE:" << std::endl
+ << " dtsel -s=<outfile> [options]" << std::endl;
+ std::cerr << std::endl << "OPTIONS:" << std::endl;
+ FullPrintParams(TypeFlag, 0, 1, stderr);
+}
+
+void usage(const char *msg = 0)
+{
+ if (msg){
+ std::cerr << msg << std::endl;
+ }
+ else{
+ print_help();
+ }
+}
+
// Recursive interpolated (Witten-Bell style) n-gram probability of the last
// word of `ng` given its (size-1)-word history, estimated on table `ngt`.
// `cv` counts are subtracted for leave-one-out cross-validation.
// Recurses from order `size` down to the unigram level.
double prob(ngramtable* ngt,ngram ng,int size,int cv){
  MY_ASSERT(size<=ngt->maxlevel() && size<=ng.size);
  if (size>1){
    // history = the (size-1)-gram context of ng
    ngram history=ng;
    if (ngt->get(history,size,size-1) && history.freq>cv){
      double fstar=0.0;
      double lambda;
      if (ngt->get(ng,size,size)){
        // never discount more than the observed count
        cv=(cv>ng.freq)?ng.freq:cv;
        if (ng.freq>cv){
          fstar=(double)(ng.freq-cv)/(double)(history.freq -cv + history.succ);
          lambda=(double)history.succ/(double)(history.freq -cv + history.succ);
        }else //ng.freq==cv
          // discounted count is zero: one successor type is removed too
          lambda=(double)(history.succ-1)/(double)(history.freq -cv + history.succ-1);
      }
      else
        // n-gram unseen: all mass comes from the lower order
        lambda=(double)history.succ/(double)(history.freq -cv + history.succ);

      return fstar + lambda * prob(ngt,ng,size-1,cv);
    }
    else return prob(ngt,ng,size-1,cv);

  }else{ //unigram branch
    if (ngt->get(ng,1,1) && ng.freq>cv)
      return (double)(ng.freq-cv)/(ngt->totfreq()-1);
    else{
      //cerr << "backoff to oov unigram " << ng.freq << " " << cv << "\n";
      // unseen (or fully discounted) word: fall back to the OOV unigram
      *ng.wordp(1)=ngt->dict->oovcode();
      if (ngt->get(ng,1,1) && ng.freq>0)
        return (double)ng.freq/ngt->totfreq();
      else //use an automatic estimate of Pr(oov)
        return (double)ngt->dict->size()/(ngt->totfreq()+ngt->dict->size());
    }

  }

}
+
+
+double computePP(ngramtable* train,ngramtable* test,double oovpenalty,double& oovrate,int cv=0){
+
+
+ ngram ng2(test->dict);ngram ng1(train->dict);
+ int N=0; double H=0; oovrate=0;
+
+ test->scan(ng2,INIT,test->maxlevel());
+ while(test->scan(ng2,CONT,test->maxlevel())) {
+
+ ng1.trans(ng2);
+ H-=log(prob(train,ng1,ng1.size,cv));
+ if (*ng1.wordp(1)==train->dict->oovcode()){
+ H-=oovpenalty;
+ oovrate++;
+ }
+ N++;
+ }
+ oovrate/=N;
+ return exp(H/N);
+}
+
+
// dtsel entry point. Two modes of operation:
//  - scoring mode (no -test): builds in-domain and out-of-domain n-gram
//    tables and writes, for every out-of-domain sentence, its per-word
//    (cross-)entropy score followed by the sentence to the score file;
//  - evaluation mode (-test given): reads the score file back, grows an
//    LM block by block (blocksize words at a time) and reports perplexity
//    of the test set after each block.
int main(int argc, char **argv)
{
  char *indom=NULL;     //indomain data: one sentence per line
  char *outdom=NULL;    //domain data: one sentence per line
  char *scorefile=NULL; //score file
  char *evalset=NULL;   //evalset to measure performance

  int minfreq=2;        //frequency threshold for dictionary pruning (optional)
  int ngsz=0;           // n-gram size
  int dub=10000000;     //upper bound of true vocabulary
  int model=2;          //data selection model: 1 only in-domain cross-entropy,
  //2 cross-entropy difference.
  int cv=1;             //cross-validation parameter: 1 only in-domain cross-entropy,

  int blocksize=100000; //block-size in words
  int verbose=0;
  int useindex=0;       //provided score file includes and index
  double convergence_treshold=0;

  bool help=false;

  // Command-line registration: every option has a long and a short form.
  DeclareParams((char*)
                "min-word-freq", CMDINTTYPE|CMDMSG, &minfreq, "frequency threshold for dictionary pruning, default: 2",
                "f", CMDINTTYPE|CMDMSG, &minfreq, "frequency threshold for dictionary pruning, default: 2",

                "ngram-order", CMDSUBRANGETYPE|CMDMSG, &ngsz, 1 , MAX_NGRAM, "n-gram default size, default: 0",
                "n", CMDSUBRANGETYPE|CMDMSG, &ngsz, 1 , MAX_NGRAM, "n-gram default size, default: 0",

                "in-domain-file", CMDSTRINGTYPE|CMDMSG, &indom, "indomain data file: one sentence per line",
                "i", CMDSTRINGTYPE|CMDMSG, &indom, "indomain data file: one sentence per line",

                "out-domain-file", CMDSTRINGTYPE|CMDMSG, &outdom, "domain data file: one sentence per line",
                "o", CMDSTRINGTYPE|CMDMSG, &outdom, "domain data file: one sentence per line",

                "score-file", CMDSTRINGTYPE|CMDMSG, &scorefile, "score output file",
                "s", CMDSTRINGTYPE|CMDMSG, &scorefile, "score output file",

                "dictionary-upper-bound", CMDINTTYPE|CMDMSG, &dub, "upper bound of true vocabulary, default: 10000000",
                "dub", CMDINTTYPE|CMDMSG, &dub, "upper bound of true vocabulary, default: 10000000",

                "model", CMDSUBRANGETYPE|CMDMSG, &model, 1 , 2, "data selection model: 1 only in-domain cross-entropy, 2 cross-entropy difference; default: 2",
                "m", CMDSUBRANGETYPE|CMDMSG, &model, 1 , 2, "data selection model: 1 only in-domain cross-entropy, 2 cross-entropy difference; default: 2",

                "cross-validation", CMDSUBRANGETYPE|CMDMSG, &cv, 1 , 3, "cross-validation parameter: 1 only in-domain cross-entropy; default: 1",
                "cv", CMDSUBRANGETYPE|CMDMSG, &cv, 1 , 3, "cross-validation parameter: 1 only in-domain cross-entropy; default: 1",

                "test", CMDSTRINGTYPE|CMDMSG, &evalset, "evaluation set file to measure performance",
                "t", CMDSTRINGTYPE|CMDMSG, &evalset, "evaluation set file to measure performance",

                "block-size", CMDINTTYPE|CMDMSG, &blocksize, "block-size in words, default: 100000",
                "bs", CMDINTTYPE|CMDMSG, &blocksize, "block-size in words, default: 100000",

                "convergence-threshold", CMDDOUBLETYPE|CMDMSG, &convergence_treshold, "convergence threshold, default: 0",
                "c", CMDDOUBLETYPE|CMDMSG, &convergence_treshold, "convergence threshold, default: 0",

                "index", CMDSUBRANGETYPE|CMDMSG, &useindex,0,1, "provided score file includes and index, default: 0",
                "x", CMDSUBRANGETYPE|CMDMSG, &useindex,0,1, "provided score file includes and index, default: 0",

                "verbose", CMDSUBRANGETYPE|CMDMSG, &verbose,0,2, "verbose level, default: 0",
                "v", CMDSUBRANGETYPE|CMDMSG, &verbose,0,2, "verbose level, default: 0",
                "Help", CMDBOOLTYPE|CMDMSG, &help, "print this help",
                "h", CMDBOOLTYPE|CMDMSG, &help, "print this help",

                (char *)NULL
                );



  GetParams(&argc, &argv, (char*) NULL);

  if (help){
    usage();
    exit_error(IRSTLM_NO_ERROR);
  }
  if (scorefile==NULL) {
    usage();
    exit_error(IRSTLM_NO_ERROR);
  }

  if (!evalset && (!indom || !outdom)){
    exit_error(IRSTLM_ERROR_DATA, "Must specify in-domain and out-domain data files");
  };

  //score file is always required: either as output or as input
  // NOTE(review): redundant — scorefile==NULL was already handled above.
  if (!scorefile){
    exit_error(IRSTLM_ERROR_DATA, "Must specify score file");
  };

  if (!evalset && !model){
    exit_error(IRSTLM_ERROR_DATA, "Must specify data selection model");
  }

  if (evalset && (convergence_treshold<0 || convergence_treshold > 0.1)){
    exit_error(IRSTLM_ERROR_DATA, "Convergence threshold must be between 0 and 0.1");
  }

  TABLETYPE table_type=COUNT;


  if (!evalset){
    // ---------- scoring mode ----------

    //computed dictionary on indomain data
    dictionary *dict = new dictionary(indom,1000000,0);
    dictionary *pd=new dictionary(dict,true,minfreq); // prune below minfreq
    delete dict;dict=pd;

    //build in-domain table restricted to the given dictionary
    ngramtable *indngt=new ngramtable(indom,ngsz,NULL,dict,NULL,0,0,NULL,0,table_type);
    double indoovpenalty=-log(dub-indngt->dict->size());
    ngram indng(indngt->dict);
    int indoovcode=indngt->dict->oovcode();

    //build out-domain table restricted to the in-domain dictionary
    // If the file carries a leading index column, strip it with cut(1).
    char command[1000]="";

    if (useindex)
      sprintf(command,"cut -d \" \" -f 2- %s",outdom);
    else
      sprintf(command,"%s",outdom);


    ngramtable *outdngt=new ngramtable(command,ngsz,NULL,dict,NULL,0,0,NULL,0,table_type);
    double outdoovpenalty=-log(dub-outdngt->dict->size());
    ngram outdng(outdngt->dict);
    int outdoovcode=outdngt->dict->oovcode();

    cerr << "dict size idom: " << indngt->dict->size() << " odom: " << outdngt->dict->size() << "\n";
    cerr << "oov penalty idom: " << indoovpenalty << " odom: " << outdoovpenalty << "\n";

    //go through the odomain sentences
    int bos=dict->encode(dict->BoS());
    mfstream inp(outdom,ios::in); ngram ng(dict);
    mfstream txt(outdom,ios::in); // NOTE(review): txt appears unused below — confirm before removing
    mfstream output(scorefile,ios::out);


    int linenumber=1; string line;
    int length=0;
    float deltaH=0;     // accumulated log-prob difference of the sentence
    float deltaHoov=0;  // accumulated OOV-penalty difference
    int words=0;string index;

    while (getline(inp,line)){

      istringstream lninp(line);

      linenumber++;

      if (useindex) lninp >> index;

      // reset ngram at begin of sentence
      ng.size=1; deltaH=0;deltaHoov=0; length=0;

      while(lninp>>ng){

        if (*ng.wordp(1)==bos) continue;

        length++; words++;

        if ((words % 1000000)==0) cerr << ".";

        if (ng.size>ngsz) ng.size=ngsz;
        indng.trans(ng);outdng.trans(ng);

        if (model==1){//compute cross-entropy
          deltaH-=log(prob(indngt,indng,indng.size,0));
          deltaHoov-=(*indng.wordp(1)==indoovcode?indoovpenalty:0);
        }

        if (model==2){ //compute cross-entropy difference
          deltaH+=log(prob(outdngt,outdng,outdng.size,cv))-log(prob(indngt,indng,indng.size,0));
          deltaHoov+=(*outdng.wordp(1)==outdoovcode?outdoovpenalty:0)-(*indng.wordp(1)==indoovcode?indoovpenalty:0);
        }
      }

      // per-word score followed by the original line
      output << (deltaH + deltaHoov)/length << " " << line << "\n";
    }
  }
  else{
    // ---------- evaluation mode ----------

    //build in-domain LM from evaluation set
    ngramtable *tstngt=new ngramtable(evalset,ngsz,NULL,NULL,NULL,0,0,NULL,0,table_type);

    //build empty out-domain LM
    ngramtable *outdngt=new ngramtable(NULL,ngsz,NULL,NULL,NULL,0,0,NULL,0,table_type);

    //if indomain data is passed then limit comparison to its dictionary
    dictionary *dict = NULL;
    if (indom){
      cerr << "dtsel: limit evaluation dict to indomain words with freq >=" << minfreq << "\n";
      //computed dictionary on indomain data
      dict = new dictionary(indom,1000000,0);
      dictionary *pd=new dictionary(dict,true,minfreq);
      delete dict;dict=pd;
      outdngt->dict=dict;
    }

    dictionary* outddict=outdngt->dict;

    //get codes of <s>, </s> and UNK
    outddict->incflag(1);
    int bos=outddict->encode(outddict->BoS());
    int oov=outddict->encode(outddict->OOV());
    outddict->incflag(0);
    outddict->oovcode(oov);


    double oldPP=dub; double newPP=0; double oovrate=0;

    long totwords=0; long totlines=0; long nextstep=blocksize;

    double score; string index;

    mfstream outd(scorefile,ios::in); string line;

    //initialize n-gram
    ngram ng(outdngt->dict); for (int i=1;i<ngsz;i++) ng.pushc(bos); ng.freq=1;

    //check if to use open or closed voabulary
    if (!dict) outddict->incflag(1);

    while (getline(outd,line)){

      istringstream lninp(line);

      //skip score and eventually the index
      lninp >> score; if (useindex) lninp >> index;

      while (lninp >> ng){

        if (*ng.wordp(1) == bos) continue;

        if (ng.size>ngsz) ng.size=ngsz;

        outdngt->put(ng);

        totwords++;
      }

      totlines++;

      if (totwords>=nextstep){ //if block is complete

        // freeze the vocabulary while measuring PP on the test set
        if (!dict) outddict->incflag(0);

        newPP=computePP(outdngt,tstngt,-log(dub-outddict->size()),oovrate);

        if (!dict) outddict->incflag(1);

        cout << totwords << " " << newPP;
        if (verbose) cout << " " << totlines << " " << oovrate;
        cout << "\n";

        // stop early once relative PP improvement falls below the threshold
        if (convergence_treshold && (oldPP-newPP)/oldPP < convergence_treshold) return 1;

        oldPP=newPP;

        nextstep+=blocksize;
      }
    }

    // final evaluation on whatever remains after the last full block
    if (!dict) outddict->incflag(0);
    newPP=computePP(outdngt,tstngt,-log(dub-outddict->size()),oovrate);
    cout << totwords << " " << newPP;
    if (verbose) cout << " " << totlines << " " << oovrate;

  }

}
+
+
+
diff --git a/src/gzfilebuf.h b/src/gzfilebuf.h
new file mode 100644
index 0000000..4699943
--- /dev/null
+++ b/src/gzfilebuf.h
@@ -0,0 +1,90 @@
+// $Id: gzfilebuf.h 236 2009-02-03 13:25:19Z nicolabertoldi $
+
+#ifndef _GZFILEBUF_H_
+#define _GZFILEBUF_H_
+
+#include <cstdio>
+#include <streambuf>
+#include <cstring>
+#include <zlib.h>
+#include <iostream>
+
+class gzfilebuf : public std::streambuf
+{
+public:
+ gzfilebuf(const char *filename) {
+ _gzf = gzopen(filename, "rb");
+ setg (_buff+sizeof(int), // beginning of putback area
+ _buff+sizeof(int), // read position
+ _buff+sizeof(int)); // end position
+ }
+ ~gzfilebuf() {
+ gzclose(_gzf);
+ }
+protected:
+ virtual int_type overflow (int_type /* unused parameter: c */) {
+ std::cerr << "gzfilebuf::overflow is not implemented" << std::endl;;
+ throw;
+ }
+
+ // write multiple characters
+ virtual std::streamsize xsputn (const char* /* unused parameter: s */, std::streamsize /* unused parameter: num */) {
+ std::cerr << "gzfilebuf::xsputn is not implemented" << std::endl;;
+ throw;
+ }
+
+ virtual std::streampos seekpos ( std::streampos /* unused parameter: sp */, std::ios_base::openmode /* unused parameter: which */= std::ios_base::in | std::ios_base::out ) {
+ std::cerr << "gzfilebuf::seekpos is not implemented" << std::endl;;
+ throw;
+ }
+
+ //read one character
+ virtual int_type underflow () {
+ // is read position before end of _buff?
+ if (gptr() < egptr()) {
+ return traits_type::to_int_type(*gptr());
+ }
+
+ /* process size of putback area
+ * - use number of characters read
+ * - but at most four
+ */
+ unsigned int numPutback = gptr() - eback();
+ if (numPutback > sizeof(int)) {
+ numPutback = sizeof(int);
+ }
+
+ /* copy up to four characters previously read into
+ * the putback _buff (area of first four characters)
+ */
+ std::memmove (_buff+(sizeof(int)-numPutback), gptr()-numPutback,
+ numPutback);
+
+ // read new characters
+ int num = gzread(_gzf, _buff+sizeof(int), _buffsize-sizeof(int));
+ if (num <= 0) {
+ // ERROR or EOF
+ return EOF;
+ }
+
+ // reset _buff pointers
+ setg (_buff+(sizeof(int)-numPutback), // beginning of putback area
+ _buff+sizeof(int), // read position
+ _buff+sizeof(int)+num); // end of buffer
+
+ // return next character
+ return traits_type::to_int_type(*gptr());
+ }
+
+ std::streamsize xsgetn (char* s,
+ std::streamsize num) {
+ return gzread(_gzf,s,num);
+ }
+
+private:
+ gzFile _gzf;
+ static const unsigned int _buffsize = 1024;
+ char _buff[_buffsize];
+};
+
+#endif
diff --git a/src/htable.cpp b/src/htable.cpp
new file mode 100644
index 0000000..99e8388
--- /dev/null
+++ b/src/htable.cpp
@@ -0,0 +1,107 @@
+// $Id: htable.cpp 3680 2010-10-13 09:10:21Z bertoldi $
+
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <iostream>
+#include "mempool.h"
+#include "htable.h"
+#include "util.h"
+
+using namespace std;
+
// For int* keys the key length is stored as a number of ints:
// convert the byte length kl accordingly.
template <>
void htable<int*>::set_keylen(int kl)
{
  keylen=kl/sizeof(int);
  return;
}
+
// For char* keys the key length is kept as a plain byte count.
template <>
void htable<char*>::set_keylen(int kl)
{
  keylen=kl;
  return;
}
+
+template <>
+address htable<int *>::Hash(int* key)
+{
+ address h;
+ register int i;
+
+ //Thomas Wang's 32 bit Mix Function
+ for (i=0,h=0; i<keylen; i++) {
+ h+=key[i];
+ h += ~(h << 15);
+ h ^= (h >> 10);
+ h += (h << 3);
+ h ^= (h >> 6);
+ h += ~(h << 11);
+ h ^= (h >> 16);
+ };
+
+ return h;
+}
+
+template <>
+address htable<char *>::Hash(char* key)
+{
+ //actually char* key is a char**, i.e. a pointer to a char*
+ char *Key = *(char**)key;
+ int length=strlen(Key);
+
+ register address h=0;
+ register int i;
+
+ for (i=0,h=0; i<length; i++)
+ h = h * Prime1 ^ (Key[i] - ' ');
+ h %= Prime2;
+
+ return h;
+}
+
+template <>
+int htable<int*>::Comp(int *key1, int *key2) const
+{
+ MY_ASSERT(key1 && key2);
+
+ register int i;
+
+ for (i=0; i<keylen; i++)
+ if (key1[i]!=key2[i]) return 1;
+ return 0;
+}
+
// Compares two char* keys. Each argument is actually a char** (a pointer
// to the C string); returns strcmp semantics (0 when equal).
template <>
int htable<char*>::Comp(char *key1, char *key2) const
{
  MY_ASSERT(key1 && key2);

  char *Key1 = *(char**)key1;
  char *Key2 = *(char**)key2;

  MY_ASSERT(Key1 && Key2);

  return (strcmp(Key1,Key2));
}
diff --git a/src/htable.h b/src/htable.h
new file mode 100644
index 0000000..54ca184
--- /dev/null
+++ b/src/htable.h
@@ -0,0 +1,270 @@
+// $Id: htable.h 3680 2010-10-13 09:10:21Z bertoldi $
+
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+
+#ifndef MF_HTABLE_H
+#define MF_HTABLE_H
+
+using namespace std;
+
+#include <iostream>
+#include <string>
+#include <cstring>
+#include "mempool.h"
+
+#define Prime1 37
+#define Prime2 1048583
+#define BlockSize 100
+
+typedef unsigned int address;
+
+// Fast arithmetic, relying on powers of 2,
+// and on pre-processor concatenation property
+//use as
//! Bucket node: a stored key plus the collision-chain link.
template <class T>
struct entry {
  T key;
  entry* next; // secret from user
};


//! Actions accepted by htable::scan (find/enter belong to search APIs).
typedef enum {HT_FIND, //!< search: find an entry
              HT_ENTER, //!< search: enter an entry
              HT_INIT, //!< scan: start scan
              HT_CONT //!< scan: continue scan
} HT_ACTION;
+
//!T is the type of the key and should be (int*) or (char*)
//! Chained hash table storing keys of type T; entries are allocated from
//! an internal mempool. Key hashing/comparison are specialized per T.
template <class T>
class htable
{
  int size;            //!< table size (number of buckets)
  int keylen;          //!< key length (ints for int*, bytes for char*)
  entry<T> **table;    //!< hash table (bucket array)
  int scan_i;          //!< scan support: current bucket
  entry<T> *scan_p;    //!< scan support: current chain node
  // statistics
  long keys;           //!< # of entries
  long accesses;       //!< # of accesses
  long collisions;     //!< # of collisions

  mempool *memory;     //!< memory pool for chain entries

public:

  //! Creates a hash table with n buckets and key length kl
  htable(int n,int kl=0);

  //! Destroys the hash table (stored keys themselves are not freed)
  ~htable();

  //! Sets the key length; interpretation depends on T (see specializations)
  void set_keylen(int kl);

  //! Computes the hash function
  address Hash(const T key);

  //! Compares the keys of two entries (0 when equal)
  int Comp(const T Key1, const T Key2) const;

  //! Searches for an item; returns the stored key or NULL
  T find(T item);
  //! Inserts an item; returns the stored (possibly pre-existing) key,
  //! or NULL when the pool is exhausted
  T insert(T item);

  //! Scans the content (HT_INIT to start, HT_CONT to continue)
  T scan(HT_ACTION action);

  //! Prints statistics
  void stat() const ;

  //! Print a map of memory use
  void map(std::ostream& co=std::cout, int cols=80);

  //! Returns amount of used memory
  int used() const {
    return size * sizeof(entry<T> **) + memory->used();
  }

};
+
+
+
+template <class T>
+htable<T>::htable(int n,int kl)
+{
+
+ memory=new mempool( sizeof(entry<T>) , BlockSize );
+
+ table = new entry<T>* [ size=n ];
+
+ memset(table,0,sizeof(entry<T> *) * n );
+
+ set_keylen(kl);
+
+ keys = accesses = collisions = 0;
+}
+
// Releases the bucket array and the pool holding all chain entries.
// The stored keys themselves are not freed here.
template <class T>
htable<T>::~htable()
{
  delete []table;
  delete memory;
}
+
+template <class T>
+T htable<T>::find(T key)
+{
+// std::cerr << "T htable<T>::find(T key) size:" << size << std::endl;
+ address h;
+ entry<T> *q,**p;
+
+ accesses++;
+
+ h = Hash(key);
+// std::cerr << "T htable<T>::find(T key) h:" << h << std::endl;
+
+ p=&table[h%size];
+ q=*p;
+
+ /* Follow collision chain */
+ while (q != NULL && Comp(q->key,key)) {
+ p = &(q->next);
+ q = q->next;
+
+ collisions++;
+ }
+
+ if (q != NULL) return q->key; /* found */
+
+ return NULL;
+}
+
// Inserts `key` if absent; returns the stored key (the pre-existing one on
// a duplicate, the new one otherwise), or NULL when the pool cannot supply
// a new entry. The key pointer itself is stored — ownership stays with the
// caller.
template <class T>
T htable<T>::insert(T key)
{
  address h;
  entry<T> *q,**p;

  accesses++;

  h = Hash(key);

  // p tracks the link to patch when appending at the end of the chain
  p=&table[h%size];
  q=*p;

  /* Follow collision chain */
  while (q != NULL && Comp(q->key,key)) {
    p = &(q->next);
    q = q->next;

    collisions++;
  }

  if (q != NULL) return q->key; /* found */

  /* not found */
  if ((q = (entry<T> *)memory->allocate()) == NULL) /* no room */
    return NULL;

  /* link into chain */
  *p = q;

  /* Initialize new element */
  q->key = key;
  q->next = NULL;
  keys++;

  return q->key;
}
+
// Iterates over all stored keys. Call with HT_INIT once (returns NULL),
// then with HT_CONT until NULL is returned. State is kept in scan_i/scan_p,
// so only one scan can be active at a time.
template <class T>
T htable<T>::scan(HT_ACTION action)
{
  if (action == HT_INIT) {
    scan_i=0;
    scan_p=table[0];
    return NULL;
  }

  // if scan_p==NULL go to the first non null pointer
  while ((scan_p==NULL) && (++scan_i<size)) scan_p=table[scan_i];

  if (scan_p!=NULL) {
    T k = scan_p->key;
    scan_p=(entry<T> *)scan_p->next;
    return k;
  };

  return NULL;
}
+
+
// Prints an occupancy map of the table to `co`, one character per bucket,
// `cols` buckets per row: '.' empty, '-' 1..5 chain entries, '#' more.
template <class T>
void htable<T>::map(ostream& co,int cols)
{

  entry<T> *p;
  char* img=new char[cols+1];

  img[cols]='\0';
  memset(img,'.',cols);

  co << "htable memory map: . (0 items), - (<5), # (>5)\n";

  for (int i=0; i<size; i++) {
    // count the chain length of bucket i
    int n=0;
    p=table[i];

    while(p!=NULL) {
      n++;
      p=(entry<T> *)p->next;
    };

    // flush a completed row
    if (i && (i % cols)==0) {
      co << img << "\n";
      memset(img,'.',cols);
    }

    if (n>0)
      img[i % cols]=n<=5?'-':'#';

  }

  // print the (possibly partial) last row
  // NOTE(review): when size is a multiple of cols this truncates the last
  // row to an empty string — confirm whether that is intended.
  img[size % cols]='\0';
  co << img << "\n";

  delete []img;
}
+
// Dumps usage statistics (bucket count, entries, accesses, collisions,
// memory footprint) to stderr.
template <class T>
void htable<T>::stat() const
{
  cerr << "htable class statistics\n";
  cerr << "size " << size
       << " keys " << keys
       << " acc " << accesses
       << " coll " << collisions
       << " used memory " << used()/1024 << "Kb\n";
};
+
+#endif
+
+
+
diff --git a/src/index.h b/src/index.h
new file mode 100644
index 0000000..3bfbd60
--- /dev/null
+++ b/src/index.h
@@ -0,0 +1,18 @@
+// $Id: index.h 236 2009-02-03 13:25:19Z nicolabertoldi $
+
+#pragma once
+
+#ifdef WIN32
+
// Windows replacement for POSIX index(): returns a pointer to the first
// occurrence of `search` in `str`, or NULL when absent.
// Fix: the original loop never incremented i, so it spun forever whenever
// the first character was not a match; it also re-ran strlen() on every
// iteration (O(n^2)).
inline const char *index(const char *str, char search)
{
  for (size_t i=0; str[i] != '\0'; ++i) {
    if (str[i]==search) return &str[i];
  }
  return NULL;
}
+
+#endif
+
+
diff --git a/src/interplm.cpp b/src/interplm.cpp
new file mode 100644
index 0000000..c671409
--- /dev/null
+++ b/src/interplm.cpp
@@ -0,0 +1,536 @@
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+
+#include <cmath>
+#include "util.h"
+#include "mfstream.h"
+#include "mempool.h"
+#include "htable.h"
+#include "dictionary.h"
+#include "n_gram.h"
+#include "mempool.h"
+#include "ngramcache.h"
+#include "ngramtable.h"
+#include "normcache.h"
+#include "interplm.h"
+
+using namespace std;
+
// Estimates (or reuses) the frequency of the OOV word in the dictionary,
// which drives the unigram smoothing: Witten-Bell style (|V|-1) when
// unismooth is set, otherwise DUB-based (dub-|V|+1) or a minimal count of 1.
void interplm::trainunigr()
{

  int oov=dict->getcode(dict->OOV());
  cerr << "oovcode: " << oov << "\n";

  if (oov>=0 && dict->freq(oov)>= dict->size()) {
    // OOV already has a plausible count: keep it
    cerr << "Using current estimate of OOV frequency " << dict->freq(oov)<< "\n";
  } else {
    oov=dict->encode(dict->OOV());
    dict->oovcode(oov);

    //choose unigram smoothing method according to
    //sample size
    //if (dict->totfreq()>100){ //witten bell
    //cerr << "select unigram smoothing: " << dict->totfreq() << "\n";

    if (unismooth) {
      // Witten-Bell: OOV mass proportional to the number of observed types
      dict->incfreq(oov,dict->size()-1);
      cerr << "Witten-Bell estimate of OOV freq:"<< (double)(dict->size()-1)/dict->totfreq() << "\n";
    } else {
      if (dict->dub()) {
        cerr << "DUB estimate of OOV size\n";
        dict->incfreq(oov,dict->dub()-dict->size()+1);
      } else {
        cerr << "1 = estimate of OOV size\n";
        dict->incfreq(oov,1);
      }
    }
  }
}
+
+double interplm::unigrWB(ngram ng)
+{
+ return
+ ((double)(dict->freq(*ng.wordp(1))+epsilon))/
+ ((double)dict->totfreq() + (double) dict->size() * epsilon);
+}
+
// Builds an interpolated LM of order `depth` on top of the n-gram table
// loaded from `ngtfile`. Initializes smoothing defaults, the per-level
// pruning thresholds, and forces the <s> unigram count to 1.
interplm::interplm(char *ngtfile,int depth,TABLETYPE tabtype):
  ngramtable(ngtfile,depth,NULL,NULL,NULL,0,0,NULL,0,tabtype)
{

  if (maxlevel()<depth) {
    exit_error(IRSTLM_ERROR_DATA, "interplm::interplm ngramtable size is too low");
  }

  lms=depth;
  unitbl=NULL;
  epsilon=1.0;          // additive unigram smoothing constant
  unismooth=1;          // 1 = Witten-Bell, 0 = Bayes/DUB (see trainunigr)
  prune_singletons=0;
  prune_top_singletons=0;

  init_prune_ngram(lms);
  print_prune_ngram();

  //doing something nasty: change counter of <s>

  int BoS=dict->encode(dict->BoS());
  if (BoS != dict->oovcode()) {
    cerr << "setting counter of Begin of Sentence to 1 ..." << "\n";
    dict->freq(BoS,1);
    cerr << "start_sent: " << (char *)dict->decode(BoS) << " "
         << dict->freq(BoS) << "\n";
  }

};
+
// Releases the per-level pruning threshold array.
interplm::~interplm()
{
  delete_prune_ngram();
}
+
// Frees the array allocated by init_prune_ngram.
void interplm::delete_prune_ngram()
{
  delete []prune_freq_threshold;
}
+
+void interplm::init_prune_ngram(int sz)
+{
+ prune_freq_threshold = new int[sz+1];
+ for (int i=0; i<=sz; ++i)
+ {
+ prune_freq_threshold[i] = 0;
+ }
+}
+
// Logs the pruning frequency threshold of every LM level (1..lms).
void interplm::print_prune_ngram()
{
  for (int i=1; i<=lms; ++i)
    VERBOSE(0,"level " << i << " prune_freq_threshold[" << i << "]=" << prune_freq_threshold[i] << "\n");
}
+
// Parses a comma-separated list of per-level pruning thresholds
// (level 1 first). Extra values beyond the LM order are ignored with a
// warning; afterwards the thresholds are made non-decreasing across levels.
void interplm::set_prune_ngram(char* values)
{
  // strtok mutates its argument, so work on a private copy
  char *s=strdup(values);
  char *tk;

  prune_freq_threshold[0]=0;
  int i=1;
  tk=strtok(s, ",");
  while (tk)
  {
    if (i<=lms)
    {
      prune_freq_threshold[i]=atoi(tk);
      VERBOSE(2,"prune_freq_threshold[" << i << "]=" << prune_freq_threshold[i] << "\n");
      tk=strtok(NULL, ",");
    }
    else
    {
      VERBOSE(2,"too many pruning frequency threshold values; kept the first values and skipped the others\n");
      break;
    }
    ++i;
  }

  // enforce monotonicity: a level may not be pruned less than the previous one
  for (int i=1; i<=lms; ++i)
  {
    if (prune_freq_threshold[i]<prune_freq_threshold[i-1])
    {
      prune_freq_threshold[i]=prune_freq_threshold[i-1];
      VERBOSE(2,"the value of the pruning frequency threshold for level " << i << " has been adjusted to value " << prune_freq_threshold[i] << "\n");
    }
  }
  print_prune_ngram();
  free(s);
}
+
+
+void interplm::set_prune_ngram(int lev, int val)
+{
+ if (lev <= lms)
+ {
+ if (val > 0)
+ {
+ prune_freq_threshold[lev] = val;
+ }
+ else
+ {
+ VERBOSE(2,"Value (" << val << ") must be larger than 0\n");
+ }
+ }
+ else
+ {
+ VERBOSE(2,"lev (" << lev << ") is larger than the lm order (" << lms<< ")\n");
+ }
+}
+
// For every history (n-1)-gram at levels 2..lms, counts its successors
// seen exactly once (s1) and exactly twice (s2) and stores them on the
// history node (succ1/succ2). Uses corrected counts when available.
void interplm::gensuccstat()
{

  ngram hg(dict);
  int s1,s2;

  cerr << "Generating successor statistics\n";


  for (int l=2; l<=lms; l++) {

    cerr << "level " << l << "\n";

    scan(hg,INIT,l-1);
    while(scan(hg,CONT,l-1)) {

      s1=s2=0;

      ngram ng=hg;
      ng.pushc(0);

      succscan(hg,ng,INIT,l);
      while(succscan(hg,ng,CONT,l)) {
        if (corrcounts && l<lms) //use corrected counts!!!
          ng.freq=getfreq(ng.link,ng.pinfo,1);

        if (ng.freq==1) s1++;
        else if (ng.freq==2) s2++;
      }

      succ2(hg.link,s2);
      succ1(hg.link,s1);
    }
  }
}
+
+
// Builds "corrected" lower-order counts (Kneser-Ney style type counts):
// each (l)-gram count becomes the number of distinct (l+1)-grams it
// completes, except for n-grams starting with <s>, which keep their raw
// counts. Then rebuilds history counts, patches the OOV and <s> unigrams,
// and recomputes the unigram total. Sets corrcounts=1 on completion.
void interplm::gencorrcounts()
{
  cerr << "Generating corrected n-gram tables\n";

  for (int l=lms-1; l>=1; l--) {

    cerr << "level " << l << "\n";

    ngram ng(dict);
    int count=0;

    //now update counts
    scan(ng,INIT,l+1);
    while(scan(ng,CONT,l+1)) {

      ngram ng2=ng;
      ng2.size--;
      if (get(ng2,ng2.size,ng2.size)) {

        if (!ng2.containsWord(dict->BoS(),1))
          //counts number of different n-grams
          setfreq(ng2.link,ng2.pinfo,1+getfreq(ng2.link,ng2.pinfo,1),1);
        else
          // use correct count for n-gram "<s> w .. .. "
          //setfreq(ng2.link,ng2.pinfo,ng2.freq+getfreq(ng2.link,ng2.pinfo,1),1);
          setfreq(ng2.link,ng2.pinfo,ng2.freq,1);
      } else {
        // lower-order entry missing: repair the table and restart
        MY_ASSERT(lms==l+1);
        cerr << "cannot find2 " << ng2 << "count " << count << "\n";
        cerr << "inserting ngram and starting from scratch\n";
        ng2.pushw(dict->BoS());
        ng2.freq=100;
        put(ng2);

        cerr << "reset all counts at last level\n";

        scan(ng2,INIT,lms-1);
        while(scan(ng2,CONT,lms-1)) {
          setfreq(ng2.link,ng2.pinfo,0,1);
        }

        gencorrcounts();
        return;
      }
    }
  }

  cerr << "Updating history counts\n";

  // history count of an l-gram = sum of the corrected counts of all
  // (l+1)-grams extending it
  for (int l=lms-2; l>=1; l--) {

    cerr << "level " << l << "\n";

    cerr << "reset counts\n";

    ngram ng(dict);
    scan(ng,INIT,l);
    while(scan(ng,CONT,l)) {
      freq(ng.link,ng.pinfo,0);
    }

    scan(ng,INIT,l+1);
    while(scan(ng,CONT,l+1)) {

      ngram ng2=ng;
      get(ng2,l+1,l);
      freq(ng2.link,ng2.pinfo,freq(ng2.link,ng2.pinfo)+getfreq(ng.link,ng.pinfo,1));
    }
  }

  cerr << "Adding unigram of OOV word if missing\n";
  ngram ng(dict,maxlevel());
  for (int i=1; i<=maxlevel(); i++)
    *ng.wordp(i)=dict->oovcode();

  if (!get(ng,lms,1)) {
    // oov is missing in the ngram-table
    // f(oov) = dictionary size (Witten Bell) (excluding oov itself)
    ng.freq=dict->size()-1;
    cerr << "adding oov unigram |" << ng << "| with frequency " << ng.freq << "\n";
    put(ng);
    get(ng,lms,1);
    setfreq(ng.link,ng.pinfo,ng.freq,1);
  }

  cerr << "Replacing unigram of BoS \n";
  if (dict->encode(dict->BoS()) != dict->oovcode()) {
    ngram ng(dict,1);
    *ng.wordp(1)=dict->encode(dict->BoS());

    if (get(ng,1,1)) {
      ng.freq=1; //putting Pr(<s>)=0 would create problems!!
      setfreq(ng.link,ng.pinfo,ng.freq,1);
    }
  }

  cerr << "compute unigram totfreq \n";
  int totf=0;
  scan(ng,INIT,1);
  while(scan(ng,CONT,1)) {
    totf+=getfreq(ng.link,ng.pinfo,1);
  }

  btotfreq(totf);
  cerr << "compute unigram btotfreq(totf):" << btotfreq() << "\n";

  corrcounts=1;
}
+
+double interplm::zerofreq(int lev)
+{
+ cerr << "Computing lambda: ...";
+ ngram ng(dict);
+ double N=0,N1=0;
+ scan(ng,INIT,lev);
+ while(scan(ng,CONT,lev)) {
+ if ((lev==1) && (*ng.wordp(1)==dict->oovcode()))
+ continue;
+ N+=ng.freq;
+ if (ng.freq==1) N1++;
+ }
+ cerr << (double)(N1/N) << "\n";
+ return N1/N;
+}
+
+
// Evaluates the LM on `filename`: sniffs the file header to decide whether
// it is a serialized n-gram table ("nGrAm"/"NgRaM" magic) or plain text,
// then dispatches to test_ngt or test_txt respectively.
void interplm::test(char* filename,int size,bool backoff,bool checkpr,char* outpr)
{

  if (size>lmsize()) {
    exit_error(IRSTLM_ERROR_DATA, "interplm::test: wrong ngram size");
  }


  mfstream inp(filename,ios::in );

  // read only the first token to check for the n-gram table magic string
  char header[100];
  inp >> header;
  inp.close();

  if (strncmp(header,"nGrAm",5)==0 ||
      strncmp(header,"NgRaM",5)==0) {
    ngramtable ngt(filename,size,NULL,NULL,NULL,0,0,NULL,0,COUNT);
    test_ngt(ngt,size,backoff,checkpr);
  } else
    test_txt(filename,size,backoff,checkpr,outpr);
}
+
+
// Computes perplexity and OOV rate of the LM on a plain-text file, one
// n-gram stream per sentence. When `outpr` is given, per-ngram
// probabilities are also written there; `checkpr` verifies that the
// conditional distribution sums to 1 for every context.
void interplm::test_txt(char* filename,int size,bool /* unused parameter: backoff */,bool checkpr,char* outpr)
{

  cerr << "test text " << filename << " ";
  mfstream inp(filename,ios::in );
  ngram ng(dict);

  double n=0,lp=0,pr;
  double oov=0;
  cout.precision(10);
  mfstream outp(outpr?outpr:"/dev/null",ios::out );

  if (checkpr)
    cerr << "checking probabilities\n";

  while(inp >> ng)
    if (ng.size>=1) {

      // clamp the context to the requested evaluation order
      ng.size=ng.size>size?size:ng.size;

      if (dict->encode(dict->BoS()) != dict->oovcode()) {
        if (*ng.wordp(1) == dict->encode(dict->BoS())) {
          ng.size=1; //reset n-grams starting with BoS
          continue;
        }
      }

      pr=prob(ng,ng.size);

      if (outpr)
        outp << ng << "[" << ng.size << "-gram]" << " " << pr << " " << log(pr)/log(10.0) << std::endl;

      lp-=log(pr);

      n++;

      if (((int) n % 10000)==0) cerr << ".";

      if (*ng.wordp(1) == dict->oovcode()) oov++;

      if (checkpr) {
        // sum Pr(w|context) over the whole vocabulary; must be ~1
        double totp=0.0;
        int oldw=*ng.wordp(1);
        for (int c=0; c<dict->size(); c++) {
          *ng.wordp(1)=c;
          totp+=prob(ng,ng.size);
        }
        *ng.wordp(1)=oldw;

        if ( totp < (1.0 - 1e-5) || totp > (1.0 + 1e-5))
          cout << ng << " " << pr << " [t="<< totp << "] ***\n";
      }

    }

  // spread the OOV mass uniformly over the unseen part of the vocabulary
  if (oov && dict->dub()>obswrd())
    lp += oov * log(dict->dub() - obswrd());

  cout << "n=" << (int) n << " LP="
       << (double) lp
       << " PP=" << exp(lp/n)
       << " OVVRate=" << (oov)/n
       //<< " OVVLEXRate=" << (oov-in_oov_list)/n
       // << " OOVPP=" << exp((lp+oovlp)/n)

       << "\n";


  outp.close();
  inp.close();
}
+
+
// Computes perplexity and OOV rate of the LM on a pre-built n-gram table
// of order `sz`, weighting each n-gram by its frequency. `checkpr`
// verifies that each conditional distribution sums to 1.
void interplm::test_ngt(ngramtable& ngt,int sz,bool /* unused parameter: backoff */,bool checkpr)
{

  double pr;
  int n=0,c=0;
  double lp=0;
  double oov=0;
  cout.precision(10);

  if (sz > ngt.maxlevel()) {
    exit_error(IRSTLM_ERROR_DATA, "interplm::test_ngt: ngramtable has uncompatible size");
  }

  if (checkpr) cerr << "checking probabilities\n";

  cerr << "Computing PP:";

  ngram ng(dict);
  ngram ng2(ngt.dict);
  ngt.scan(ng2,INIT,sz);

  while(ngt.scan(ng2,CONT,sz)) {

    // map the n-gram from the table's dictionary into the LM's dictionary
    ng.trans(ng2);

    if (dict->encode(dict->BoS()) != dict->oovcode()) {
      if (*ng.wordp(1) == dict->encode(dict->BoS())) {
        ng.size=1; //reset n-grams starting with BoS
        continue;
      }
    }

    n+=ng.freq;
    pr=prob(ng,sz);

    lp-=(ng.freq * log(pr));

    if (*ng.wordp(1) == dict->oovcode())
      oov+=ng.freq;


    if (checkpr) {
      // sum Pr(w|context) over the whole vocabulary; must be ~1
      double totp=0.0;
      for (c=0; c<dict->size(); c++) {
        *ng.wordp(1)=c;
        totp+=prob(ng,sz);
      }

      if ( totp < (1.0 - 1e-5) ||
           totp > (1.0 + 1e-5))
        cout << ng << " " << pr << " [t="<< totp << "] ***\n";

    }

    if ((++c % 100000)==0) cerr << ".";

  }

  //double oovlp=oov * log((double)(dict->dub() - obswrd()));

  // spread the OOV mass uniformly over the unseen part of the vocabulary
  if (oov && dict->dub()>obswrd())

    lp+=oov * log((dict->dub() - obswrd()));

  cout << "n=" << (int) n << " LP="
       << (double) lp
       << " PP=" << exp(lp/n)
       << " OVVRate=" << (oov)/n
       //<< " OVVLEXRate=" << (oov-in_oov_list)/n
       // << " OOVPP=" << exp((lp+oovlp)/n)

       << "\n";

  cout.flush();

}
+
+
+/*
+main(int argc, char** argv){
+ dictionary d(argv[1]);
+
+ shiftbeta ilm(&d,argv[2],3);
+
+ ngramtable test(&d,argv[2],3);
+ ilm.train();
+ cerr << "PP " << ilm.test(test) << "\n";
+
+ ilm.savebin("newlm.lm",3);
+}
+*/
diff --git a/src/interplm.h b/src/interplm.h
new file mode 100644
index 0000000..2093f15
--- /dev/null
+++ b/src/interplm.h
@@ -0,0 +1,158 @@
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+// Basic Interpolated LM class
+
+#ifndef MF_INTERPLM_H
+#define MF_INTERPLM_H
+
+#define SHIFT_BETA 1
+#define SHIFT_ONE 2
+#define SHIFT_ZERO 3
+#define LINEAR_STB 4
+#define LINEAR_WB 5
+#define LINEAR_GT 6
+#define MIXTURE 7
+#define MOD_SHIFT_BETA 8
+#define IMPROVED_SHIFT_BETA 9
+#define KNESER_NEY 10
+#define IMPROVED_KNESER_NEY 11
+
+// Base class for interpolated / backed-off n-gram LMs estimated from an
+// ngramtable; concrete subclasses supply the discounting scheme by
+// overriding discount(), prob(), train(), etc.
+class interplm:public ngramtable
+{
+
+ int lms; // LM order (number of n-gram levels)
+
+ double epsilon; //Bayes smoothing
+
+ int unismooth; //0 Bayes, 1 Witten Bell
+
+ // pruning settings
+ int prune_singletons;
+ int prune_top_singletons;
+ int* prune_freq_threshold; // per-level frequency thresholds (allocated by init_prune_ngram)
+
+public:
+
+ int backoff; //0 interpolation, 1 Back-off
+
+ interplm(char* ngtfile,int depth=0,TABLETYPE tt=FULL);
+ virtual ~interplm();
+
+ // get (flag==-1) or set the singleton pruning flag
+ int prunesingletons(int flag=-1) {
+ return (flag==-1?prune_singletons:prune_singletons=flag);
+ }
+
+ // get (flag==-1) or set the top-level singleton pruning flag
+ int prunetopsingletons(int flag=-1) {
+ return (flag==-1?prune_top_singletons:prune_top_singletons=flag);
+ }
+
+ // true when an n-gram of level lev with frequency freq does not exceed
+ // the pruning threshold for that level
+ inline bool prune_ngram(int lev, int freq)
+ {
+ return (freq > prune_freq_threshold[lev])?false:true;
+ }
+
+ void init_prune_ngram(int sz);
+ void delete_prune_ngram();
+ void set_prune_ngram(int lev, int val);
+ void set_prune_ngram(char* values);
+ void print_prune_ngram();
+
+ void gencorrcounts();
+
+ void gensuccstat();
+
+ // dictionary upper bound (used for the OOV penalty)
+ virtual int dub() {
+ return dict->dub();
+ }
+
+ virtual int dub(int value) {
+ return dict->dub(value);
+ }
+
+ int setusmooth(int v=0) {
+ return unismooth=v;
+ }
+
+ double setepsilon(double v=1.0) {
+ return epsilon=v;
+ }
+
+ ngramtable *unitbl; // unigram statistics table
+
+ void trainunigr();
+
+ // Witten-Bell smoothed unigram probability
+ double unigrWB(ngram ng);
+ virtual double unigr(ngram ng){ return unigrWB(ng); };
+
+ double zerofreq(int lev);
+
+ inline int lmsize() const {
+ return lms;
+ }
+
+ // number of observed words in the dictionary
+ inline int obswrd() const {
+ return dict->size();
+ }
+
+ virtual int train() {
+ return 0;
+ }
+
+ virtual void adapt(char* /* unused parameter: ngtfile */, double /* unused parameter: w */) {}
+
+ virtual double prob(ngram /* unused parameter: ng */,int /* unused parameter: size */) {
+ return 0.0;
+ }
+
+ virtual double boprob(ngram /* unused parameter: ng */,int /* unused parameter: size */) {
+ return 0.0;
+ }
+
+ // perplexity evaluation over an ngramtable / a text file
+ void test_ngt(ngramtable& ngt,int sz=0,bool backoff=false,bool checkpr=false);
+
+ void test_txt(char *filename,int sz=0,bool backoff=false,bool checkpr=false,char* outpr=NULL);
+
+ void test(char* filename,int sz,bool backoff=false,bool checkpr=false,char* outpr=NULL);
+
+ virtual int discount(ngram /* unused parameter: ng */,int /* unused parameter: size */,double& /* unused parameter: fstar */ ,double& /* unused parameter: lambda */,int /* unused parameter: cv*/=0) {
+ return 0;
+ }
+
+ virtual int savebin(char* /* unused parameter: filename */,int /* unused parameter: lmsize=2 */) {
+ return 0;
+ }
+
+ virtual int netsize() {
+ return 0;
+ }
+
+ void lmstat(int level) {
+ stat(level);
+ }
+
+
+};
+
+#endif
+
+
+
+
+
+
diff --git a/src/interpolate-lm.cpp b/src/interpolate-lm.cpp
new file mode 100644
index 0000000..d3071b3
--- /dev/null
+++ b/src/interpolate-lm.cpp
@@ -0,0 +1,564 @@
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit, compile LM
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <stdexcept>
+#include <vector>
+#include <string>
+#include <stdlib.h>
+#include "cmd.h"
+#include "util.h"
+#include "math.h"
+#include "lmContainer.h"
+
+#define MAX_N 100
+/********************************/
+using namespace std;
+using namespace irstlm;
+
+// Print an error message to stderr, then abort by throwing std::runtime_error.
+inline void error(const char* message)
+{
+ std::cerr << message << "\n";
+ throw std::runtime_error(message);
+}
+
+lmContainer* load_lm(std::string file,int requiredMaxlev,int dub,int memmap, float nlf, float dlf);
+
+// Print the tool banner, usage synopsis, description and the full option
+// list (via FullPrintParams) to stderr.
+void print_help(int TypeFlag=0){
+ std::cerr << std::endl << "interpolate-lm - interpolates language models" << std::endl;
+ std::cerr << std::endl << "USAGE:" << std::endl;
+ std::cerr << " interpolate-lm [options] <lm-list-file> [lm-list-file.out]" << std::endl;
+
+ std::cerr << std::endl << "DESCRIPTION:" << std::endl;
+ std::cerr << " interpolate-lm reads a LM list file including interpolation weights " << std::endl;
+ std::cerr << " with the format: N\\n w1 lm1 \\n w2 lm2 ...\\n wN lmN\n" << std::endl;
+ std::cerr << " It estimates new weights on a development text, " << std::endl;
+ std::cerr << " computes the perplexity on an evaluation text, " << std::endl;
+ std::cerr << " computes probabilities of n-grams read from stdin." << std::endl;
+ std::cerr << " It reads LMs in ARPA and IRSTLM binary format." << std::endl;
+
+ std::cerr << std::endl << "OPTIONS:" << std::endl;
+ FullPrintParams(TypeFlag, 0, 1, stderr);
+
+}
+
+// Print msg to stderr when provided; otherwise print the full help screen.
+void usage(const char *msg = 0)
+{
+ if (msg == 0) {
+ print_help();
+ return;
+ }
+ std::cerr << msg << std::endl;
+}
+
+// Driver: reads the LM list file (weights + LM filenames), then optionally
+// (1) re-estimates the weights by EM on a development text (--learn),
+// (2) computes perplexity on an evaluation text (--eval), and/or
+// (3) scores n-grams read interactively from stdin (--score).
+int main(int argc, char **argv)
+{
+ char *slearn = NULL;
+ char *seval = NULL;
+ bool learn=false;
+ bool score=false;
+ bool sent_PP_flag = false;
+
+ int order = 0;
+ int debug = 0;
+ int memmap = 0;
+ int requiredMaxlev = 1000;
+ int dub = 10000000;
+ float ngramcache_load_factor = 0.0;
+ float dictionary_load_factor = 0.0;
+
+ bool help=false;
+ std::vector<std::string> files;
+
+ DeclareParams((char*)
+
+ "learn", CMDSTRINGTYPE|CMDMSG, &slearn, "learn optimal interpolation for text-file; default is false",
+ "l", CMDSTRINGTYPE|CMDMSG, &slearn, "learn optimal interpolation for text-file; default is false",
+ "order", CMDINTTYPE|CMDMSG, &order, "order of n-grams used in --learn (optional)",
+ "o", CMDINTTYPE|CMDMSG, &order, "order of n-grams used in --learn (optional)",
+ "eval", CMDSTRINGTYPE|CMDMSG, &seval, "computes perplexity of the specified text file",
+ "e", CMDSTRINGTYPE|CMDMSG, &seval, "computes perplexity of the specified text file",
+
+ "DictionaryUpperBound", CMDINTTYPE|CMDMSG, &dub, "dictionary upperbound to compute OOV word penalty: default 10^7",
+ "dub", CMDINTTYPE|CMDMSG, &dub, "dictionary upperbound to compute OOV word penalty: default 10^7",
+ "score", CMDBOOLTYPE|CMDMSG, &score, "computes log-prob scores of n-grams from standard input",
+ "s", CMDBOOLTYPE|CMDMSG, &score, "computes log-prob scores of n-grams from standard input",
+
+ "debug", CMDINTTYPE|CMDMSG, &debug, "verbose output for --eval option; default is 0",
+ "d", CMDINTTYPE|CMDMSG, &debug, "verbose output for --eval option; default is 0",
+ "memmap", CMDINTTYPE|CMDMSG, &memmap, "uses memory map to read a binary LM",
+ "mm", CMDINTTYPE|CMDMSG, &memmap, "uses memory map to read a binary LM",
+ "sentence", CMDBOOLTYPE|CMDMSG, &sent_PP_flag, "computes perplexity at sentence level (identified through the end symbol)",
+ "dict_load_factor", CMDFLOATTYPE|CMDMSG, &dictionary_load_factor, "sets the load factor for ngram cache; it should be a positive real value; default is 0",
+ "ngram_load_factor", CMDFLOATTYPE|CMDMSG, &ngramcache_load_factor, "sets the load factor for ngram cache; it should be a positive real value; default is false",
+ "level", CMDINTTYPE|CMDMSG, &requiredMaxlev, "maximum level to load from the LM; if value is larger than the actual LM order, the latter is taken",
+ "lev", CMDINTTYPE|CMDMSG, &requiredMaxlev, "maximum level to load from the LM; if value is larger than the actual LM order, the latter is taken",
+
+ "Help", CMDBOOLTYPE|CMDMSG, &help, "print this help",
+ "h", CMDBOOLTYPE|CMDMSG, &help, "print this help",
+
+ (char *)NULL
+ );
+
+ if (argc == 1){
+ usage();
+ exit_error(IRSTLM_NO_ERROR);
+ }
+
+ // collect positional (non-option) arguments
+ // NOTE(review): this runs before GetParams, so any option value that does
+ // not start with '-' would also be collected — confirm against cmd.c syntax
+ for(int i=1; i < argc; i++) {
+ if(argv[i][0] != '-') files.push_back(argv[i]);
+ }
+
+ GetParams(&argc, &argv, (char*) NULL);
+
+ if (help){
+ usage();
+ exit_error(IRSTLM_NO_ERROR);
+ }
+
+ if (files.size() > 2) {
+ usage();
+ exit_error(IRSTLM_ERROR_DATA,"Too many arguments");
+ }
+
+ if (files.size() < 1) {
+ usage();
+ // NOTE(review): "pecify" is a typo for "specify" in the message below
+ exit_error(IRSTLM_ERROR_DATA,"Must pecify a LM list file to read from");
+ }
+
+ std::string infile = files[0];
+ std::string outfile="";
+
+ // default output name: input basename (path stripped) + ".out"
+ if (files.size() == 1) {
+ outfile=infile;
+ //remove path information
+ std::string::size_type p = outfile.rfind('/');
+ if (p != std::string::npos && ((p+1) < outfile.size()))
+ outfile.erase(0,p+1);
+ outfile+=".out";
+ } else
+ outfile = files[1];
+
+ std::cerr << "inpfile: " << infile << std::endl;
+ learn = ((slearn != NULL)? true : false);
+
+ if (learn) std::cerr << "outfile: " << outfile << std::endl;
+ if (score) std::cerr << "interactive: " << score << std::endl;
+ if (memmap) std::cerr << "memory mapping: " << memmap << std::endl;
+ // NOTE(review): the "loading up to the LM level" message is printed twice
+ // (unconditionally here and conditionally two lines below)
+ std::cerr << "loading up to the LM level " << requiredMaxlev << " (if any)" << std::endl;
+ std::cerr << "order: " << order << std::endl;
+ if (requiredMaxlev > 0) std::cerr << "loading up to the LM level " << requiredMaxlev << " (if any)" << std::endl;
+
+ std::cerr << "dub: " << dub<< std::endl;
+
+ lmContainer *lmt[MAX_N], *start_lmt[MAX_N]; //interpolated language models
+ std::string lmf[MAX_N]; //lm filenames
+
+ float w[MAX_N]; //interpolation weights
+ int N;
+
+
+ //Loading Language Models
+ std::cerr << "Reading " << infile << "..." << std::endl;
+ std::fstream inptxt(infile.c_str(),std::ios::in);
+
+ //std::string line;
+ char line[BUFSIZ];
+ const char* words[3];
+ int tokenN;
+
+ // header line: "LMINTERPOLATION <number_of_models>"
+ inptxt.getline(line,BUFSIZ,'\n');
+ tokenN = parseWords(line,words,3);
+
+ if (tokenN != 2 || ((strcmp(words[0],"LMINTERPOLATION") != 0) && (strcmp(words[0],"lminterpolation")!=0)))
+ error((char*)"ERROR: wrong header format of configuration file\ncorrect format: LMINTERPOLATION number_of_models\nweight_of_LM_1 filename_of_LM_1\nweight_of_LM_2 filename_of_LM_2");
+
+ N=atoi(words[1]);
+ std::cerr << "Number of LMs: " << N << "..." << std::endl;
+ if(N > MAX_N) {
+ exit_error(IRSTLM_ERROR_DATA,"Can't interpolate more than MAX_N language models");
+
+ }
+
+ // one "<weight> <lm-file>" pair per line; keep the originals in start_lmt
+ // so LMs swapped by ###replace-lm commands can be restored later
+ for (int i=0; i<N; i++) {
+ inptxt.getline(line,BUFSIZ,'\n');
+ tokenN = parseWords(line,words,3);
+ if(tokenN != 2) {
+ exit_error(IRSTLM_ERROR_DATA,"Wrong input format");
+ }
+ w[i] = (float) atof(words[0]);
+ lmf[i] = words[1];
+
+ std::cerr << "i:" << i << " w[i]:" << w[i] << " lmf[i]:" << lmf[i] << std::endl;
+ start_lmt[i] = lmt[i] = load_lm(lmf[i],requiredMaxlev,dub,memmap,ngramcache_load_factor,dictionary_load_factor);
+ }
+
+ inptxt.close();
+
+ int maxorder = 0;
+ for (int i=0; i<N; i++) {
+ maxorder = (maxorder > lmt[i]->maxlevel())?maxorder:lmt[i]->maxlevel();
+ }
+
+ if (order <= 0) {
+ order = maxorder;
+ std::cerr << "order is not set or wrongly set to a non positive value; reset to the maximum order of LMs: " << order << std::endl;
+ } else if (order > maxorder) {
+ order = maxorder;
+ std::cerr << "order is too high; reset to the maximum order of LMs" << order << std::endl;
+ }
+
+ //Learning mixture weights
+ if (learn) {
+ // p[i][k] caches LM i's probability of the k-th n-gram of the dev text
+ std::vector<float> *p = new std::vector<float>[N]; //LM probabilities
+ float c[N]; //expected counts
+ float den,norm; //inner denominator, normalization term
+ float variation=1.0; // global variation between new old params
+
+ dictionary* dict=new dictionary(slearn,1000000,dictionary_load_factor);
+ ngram ng(dict);
+ int bos=ng.dict->encode(ng.dict->BoS());
+ std::ifstream dev(slearn,std::ios::in);
+
+ // first pass: cache each LM's probability for every n-gram of the dev set
+ for(;;) {
+ std::string line;
+ getline(dev, line);
+ if(dev.eof())
+ break;
+ if(dev.fail()) {
+ exit_error(IRSTLM_ERROR_IO,"Problem reading input file");
+ }
+ std::istringstream lstream(line);
+ // in-band command: replace LM number <id> with another file
+ if(line.substr(0, 29) == "###interpolate-lm:replace-lm ") {
+ std::string token, newlm;
+ int id;
+ lstream >> token >> id >> newlm;
+ if(id <= 0 || id > N) {
+ std::cerr << "LM id out of range." << std::endl;
+ delete[] p;
+ return 1;
+ }
+ id--; // count from 0 now
+ if(lmt[id] != start_lmt[id])
+ delete lmt[id];
+ lmt[id] = load_lm(newlm,requiredMaxlev,dub,memmap,ngramcache_load_factor,dictionary_load_factor);
+ continue;
+ }
+ while(lstream >> ng) {
+
+ // reset ngram at begin of sentence
+ if (*ng.wordp(1)==bos) {
+ ng.size=1;
+ continue;
+ }
+ if (order > 0 && ng.size > order) ng.size=order;
+ for (int i=0; i<N; i++) {
+ ngram ong(lmt[i]->getDict());
+ ong.trans(ng);
+ double logpr;
+ logpr = lmt[i]->clprob(ong); //LM log-prob (using caches if available)
+ p[i].push_back(pow(10.0,logpr));
+ }
+ }
+
+ for (int i=0; i<N; i++) lmt[i]->check_caches_levels();
+ }
+ dev.close();
+
+ // EM iterations: re-estimate the mixture weights until the L1 change
+ // between successive weight vectors drops below 0.01
+ while( variation > 0.01 ) {
+
+ for (int i=0; i<N; i++) c[i]=0; //reset counters
+
+ for(unsigned i = 0; i < p[0].size(); i++) {
+ den=0.0;
+ for(int j = 0; j < N; j++)
+ den += w[j] * p[j][i]; //denominator of EM formula
+ //update expected counts
+ for(int j = 0; j < N; j++)
+ c[j] += w[j] * p[j][i] / den;
+ }
+
+ norm=0.0;
+ for (int i=0; i<N; i++) norm+=c[i];
+
+ //update weights and compute distance
+ variation=0.0;
+ for (int i=0; i<N; i++) {
+ c[i]/=norm; //c[i] is now the new weight
+ variation+=(w[i]>c[i]?(w[i]-c[i]):(c[i]-w[i]));
+ w[i]=c[i]; //update weights
+ }
+ std::cerr << "Variation " << variation << std::endl;
+ }
+
+ //Saving results
+ std::cerr << "Saving in " << outfile << "..." << std::endl;
+ std::fstream outtxt(outfile.c_str(),std::ios::out);
+ outtxt << "LMINTERPOLATION " << N << "\n";
+ for (int i=0; i<N; i++) outtxt << w[i] << " " << lmf[i] << "\n";
+ outtxt.close();
+ delete[] p;
+ }
+
+ // restore any LM that was swapped in via ###replace-lm during learning
+ for(int i = 0; i < N; i++)
+ if(lmt[i] != start_lmt[i]) {
+ delete lmt[i];
+ lmt[i] = start_lmt[i];
+ }
+
+ if (seval != NULL) {
+ std::cerr << "Start Eval" << std::endl;
+
+ std::cout.setf(ios::fixed);
+ std::cout.precision(2);
+ int i;
+ int Nw=0,Noov_all=0, Noov_any=0, Nbo=0;
+ double Pr,lPr;
+ double logPr=0,PP=0;
+
+ // variables for storing sentence-based Perplexity
+ int sent_Nw=0, sent_Noov_all=0, sent_Noov_any=0, sent_Nbo=0;
+ double sent_logPr=0,sent_PP=0;
+
+ //normalize weights
+ for (i=0,Pr=0; i<N; i++) Pr+=w[i];
+ for (i=0; i<N; i++) w[i]/=Pr;
+
+ dictionary* dict=new dictionary(NULL,1000000,dictionary_load_factor);
+ dict->incflag(1);
+ ngram ng(dict);
+ int bos=ng.dict->encode(ng.dict->BoS());
+ int eos=ng.dict->encode(ng.dict->EoS());
+
+ std::fstream inptxt(seval,std::ios::in);
+
+ for(;;) {
+ std::string line;
+ getline(inptxt, line);
+ if(inptxt.eof())
+ break;
+ if(inptxt.fail()) {
+ std::cerr << "Problem reading input file " << seval << std::endl;
+ return 1;
+ }
+ std::istringstream lstream(line);
+ // in-band command: set new interpolation weights
+ if(line.substr(0, 26) == "###interpolate-lm:weights ") {
+ std::string token;
+ lstream >> token;
+ for(int i = 0; i < N; i++) {
+ if(lstream.eof()) {
+ std::cerr << "Not enough weights!" << std::endl;
+ return 1;
+ }
+ lstream >> w[i];
+ }
+ continue;
+ }
+ // in-band command: replace LM number <id> with another file
+ if(line.substr(0, 29) == "###interpolate-lm:replace-lm ") {
+ std::string token, newlm;
+ int id;
+ lstream >> token >> id >> newlm;
+ if(id <= 0 || id > N) {
+ std::cerr << "LM id out of range." << std::endl;
+ return 1;
+ }
+ id--; // count from 0 now
+ delete lmt[id];
+ lmt[id] = load_lm(newlm,requiredMaxlev,dub,memmap,ngramcache_load_factor,dictionary_load_factor);
+ continue;
+ }
+
+ double bow;
+ int bol=0;
+ char *msp;
+ unsigned int statesize;
+
+ while(lstream >> ng) {
+
+ // reset ngram at begin of sentence
+ if (*ng.wordp(1)==bos) {
+ ng.size=1;
+ continue;
+ }
+ if (order > 0 && ng.size > order) ng.size=order;
+
+
+ if (ng.size>=1) {
+
+ int minbol=MAX_NGRAM; //minimum backoff level of the mixture
+ bool OOV_all_flag=true; //OOV flag wrt all LM[i]
+ bool OOV_any_flag=false; //OOV flag wrt any LM[i]
+ float logpr;
+
+ Pr = 0.0;
+ for (i=0; i<N; i++) {
+
+ ngram ong(lmt[i]->getDict());
+ ong.trans(ng);
+ logpr = lmt[i]->clprob(ong,&bow,&bol,&msp,&statesize); //actual prob of the interpolation
+ //logpr = lmt[i]->clprob(ong,&bow,&bol); //LM log-prob
+
+ Pr+=w[i] * pow(10.0,logpr); //actual prob of the interpolation
+ if (bol < minbol) minbol=bol; //backoff of LM[i]
+
+ if (*ong.wordp(1) != lmt[i]->getDict()->oovcode()) OOV_all_flag=false; //OOV wrt LM[i]
+ if (*ong.wordp(1) == lmt[i]->getDict()->oovcode()) OOV_any_flag=true; //OOV wrt LM[i]
+ }
+
+ lPr=log(Pr)/M_LN10; // natural log -> log10
+ logPr+=lPr;
+ sent_logPr+=lPr;
+
+ if (debug==1) {
+ std::cout << ng.dict->decode(*ng.wordp(1)) << " [" << ng.size-minbol << "]" << " ";
+ if (*ng.wordp(1)==eos) std::cout << std::endl;
+ }
+ if (debug==2)
+ std::cout << ng << " [" << ng.size-minbol << "-gram]" << " " << log(Pr) << std::endl;
+
+ if (minbol) {
+ Nbo++; //all LMs have back-offed by at least one
+ sent_Nbo++;
+ }
+
+ if (OOV_all_flag) {
+ Noov_all++; //word is OOV wrt all LM
+ sent_Noov_all++;
+ }
+ if (OOV_any_flag) {
+ Noov_any++; //word is OOV wrt any LM
+ sent_Noov_any++;
+ }
+
+ Nw++;
+ sent_Nw++;
+
+ // at end of sentence, optionally report and reset per-sentence stats
+ if (*ng.wordp(1)==eos && sent_PP_flag) {
+ sent_PP=exp((-sent_logPr * log(10.0)) /sent_Nw);
+ std::cout << "%% sent_Nw=" << sent_Nw
+ << " sent_PP=" << sent_PP
+ << " sent_Nbo=" << sent_Nbo
+ << " sent_Noov=" << sent_Noov_all
+ << " sent_OOV=" << (float)sent_Noov_all/sent_Nw * 100.0 << "%"
+ << " sent_Noov_any=" << sent_Noov_any
+ << " sent_OOV_any=" << (float)sent_Noov_any/sent_Nw * 100.0 << "%" << std::endl;
+ //reset statistics for sentence based Perplexity
+ sent_Nw=sent_Noov_any=sent_Noov_all=sent_Nbo=0;
+ sent_logPr=0.0;
+ }
+
+
+ if ((Nw % 10000)==0) std::cerr << ".";
+ }
+ }
+ }
+
+ PP=exp((-logPr * M_LN10) /Nw);
+
+ std::cout << "%% Nw=" << Nw
+ << " PP=" << PP
+ << " Nbo=" << Nbo
+ << " Noov=" << Noov_all
+ << " OOV=" << (float)Noov_all/Nw * 100.0 << "%"
+ << " Noov_any=" << Noov_any
+ << " OOV_any=" << (float)Noov_any/Nw * 100.0 << "%" << std::endl;
+
+ };
+
+
+ if (score == true) {
+
+
+ dictionary* dict=new dictionary(NULL,1000000,dictionary_load_factor);
+ dict->incflag(1); // start generating the dictionary;
+ ngram ng(dict);
+ int bos=ng.dict->encode(ng.dict->BoS());
+
+ double Pr,logpr;
+
+ double bow;
+ int bol=0, maxbol=0;
+ unsigned int maxstatesize, statesize;
+ int i,n=0;
+ std::cout << "> ";
+ while(std::cin >> ng) {
+
+ // reset ngram at begin of sentence
+ if (*ng.wordp(1)==bos) {
+ ng.size=1;
+ continue;
+ }
+
+ // NOTE(review): n-grams shorter than the maximum LM order are not
+ // scored ("p= NULL" below) — confirm this gate is intended
+ if (ng.size>=maxorder) {
+
+ if (order > 0 && ng.size > order) ng.size=order;
+ n++;
+ maxstatesize=0;
+ maxbol=0;
+ Pr=0.0;
+ for (i=0; i<N; i++) {
+ ngram ong(lmt[i]->getDict());
+ ong.trans(ng);
+ logpr = lmt[i]->clprob(ong,&bow,&bol,NULL,&statesize); //LM log-prob (using caches if available)
+
+ Pr+=w[i] * pow(10.0,logpr); //actual prob of the interpolation
+ std::cout << "lm " << i << ":" << " logpr: " << logpr << " weight: " << w[i] << std::endl;
+ if (maxbol<bol) maxbol=bol;
+ if (maxstatesize<statesize) maxstatesize=statesize;
+ }
+
+ std::cout << ng << " p= " << log(Pr) << " bo= " << maxbol << " recombine= " << maxstatesize << std::endl;
+
+ if ((n % 10000000)==0) {
+ std::cerr << "." << std::endl;
+ for (i=0; i<N; i++) lmt[i]->check_caches_levels();
+ }
+
+ } else {
+ std::cout << ng << " p= NULL" << std::endl;
+ }
+ std::cout << "> ";
+ }
+
+
+ }
+
+ for (int i=0; i<N; i++) delete lmt[i];
+
+ return 0;
+}
+
+// Load one LM from file: create a container of the detected type, restrict
+// it to requiredMaxlev levels, memory-map if requested, set the OOV penalty
+// from the dictionary upper bound and enable the probability caches.
+lmContainer* load_lm(std::string file,int requiredMaxlev,int dub,int memmap, float nlf, float dlf)
+{
+ lmContainer* lmt = lmContainer::CreateLanguageModel(file,nlf,dlf);
+
+ lmt->setMaxLoadedLevel(requiredMaxlev);
+
+ lmt->load(file,memmap);
+
+ if (dub) lmt->setlogOOVpenalty((int)dub);
+
+ //use caches to save time (only if PS_CACHE_ENABLE is defined through compilation flags)
+ lmt->init_caches(lmt->maxlevel());
+ return lmt;
+}
diff --git a/src/linearlm.cpp b/src/linearlm.cpp
new file mode 100644
index 0000000..2e9b2f8
--- /dev/null
+++ b/src/linearlm.cpp
@@ -0,0 +1,233 @@
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+
+#include <string.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <math.h>
+#include "mfstream.h"
+#include "mempool.h"
+#include "htable.h"
+#include "dictionary.h"
+#include "n_gram.h"
+#include "mempool.h"
+#include "ngramtable.h"
+#include "ngramcache.h"
+#include "normcache.h"
+#include "interplm.h"
+#include "mdiadapt.h"
+#include "linearlm.h"
+#include "util.h"
+
+namespace irstlm {
+ //
+ //Linear interpolated language model: Witten & Bell discounting scheme
+ //
+
+
+ // Build a Witten-Bell discounted LM of the given depth from the n-gram
+ // table in ngtfile; prunefreq is the history pruning threshold.
+ linearwb::linearwb(char* ngtfile,int depth,int prunefreq,TABLETYPE tt):
+ mdiadaptlm(ngtfile,depth,tt)
+ {
+ prunethresh=prunefreq;
+ cerr << "PruneThresh: " << prunethresh << "\n";
+
+ };
+
+
+ // Estimate the model: unigram statistics first, then the successor
+ // statistics required by the Witten-Bell discounting formula. Returns 1.
+ int linearwb::train()
+ {
+ trainunigr();
+
+ gensuccstat();
+
+ return 1;
+ }
+
+
+ // Witten-Bell discounting: compute the discounted relative frequency
+ // fstar and the interpolation weight lambda for n-gram ng_ at order size.
+ // cv counts are removed from the statistics for cross-validation
+ // (leaving-one-out) estimation. Returns 1; results are passed back
+ // through fstar and lambda.
+ int linearwb::discount(ngram ng_,int size,double& fstar,double& lambda,int cv)
+ {
+ VERBOSE(3,"linearwb::discount(ngram ng_,int size,double& fstar,double& lambda,int cv) ng_:|" << ng_ << "| size:" << size << " cv:" << cv<< std::endl);
+ ngram ng(dict);
+ ng.trans(ng_);
+
+ if (size > 1) {
+ ngram history=ng;
+
+ // the history must exist (net of the cv counts) and, for orders >= 3,
+ // survive the history pruning threshold
+ if (ng.ckhisto(size) && get(history,size,size-1) && (history.freq>cv) &&
+ ((size < 3) || ((history.freq-cv) > prunethresh))) {
+ // apply history pruning on trigrams only
+
+ if (get(ng,size,size) && (!prunesingletons() || ng.freq>1 || size<3)) {
+ // apply frequency pruning on trigrams only
+
+ cv=(ng.freq<cv)?ng.freq:cv; //hence, ng.freq>=cv
+
+ if (ng.freq>cv) {
+ fstar=(double)(ng.freq-cv)/(double)(history.freq - cv + history.succ);
+ lambda=(double)history.succ/(double)(history.freq - cv + history.succ);
+
+ if (size>=3 && prunesingletons()){ // correction due to frequency pruning
+ lambda+=(double)succ1(history.link)/(double)(history.freq - cv + history.succ);
+ // succ1(history.link) is not affected when ng.freq > cv
+ }
+ } else { // ng.freq == cv
+ fstar=0.0;
+ lambda=(double)(history.succ - 1)/ (double)(history.freq - cv + history.succ - 1); // remove cv n-grams from data
+
+ if (size>=3 && prunesingletons()){ // correction due to frequency pruning
+ // BUGFIX: the whole corrected successor count must be divided by
+ // the denominator; previously only the ternary term was divided
+ // (operator precedence: '/' binds tighter than '-'), so the raw
+ // count succ1(...) was added to lambda, which could exceed 1.
+ lambda+=((double)succ1(history.link)-(cv==1 && ng.freq==1?1:0))/(double)(history.freq - cv + history.succ - 1);
+ }
+
+ }
+ } else {
+ fstar=0.0;
+ lambda=(double)history.succ/(double)(history.freq + history.succ);
+
+ if (size>=3 && prunesingletons()){ // correction due to frequency pruning
+ lambda+=(double)succ1(history.link)/(double)(history.freq + history.succ);
+ }
+ }
+
+ //cerr << "ngram :" << ng << "\n";
+ // if current word is OOV then back-off to unigrams!
+
+ if (*ng.wordp(1)==dict->oovcode()) {
+ lambda+=fstar;
+ fstar=0.0;
+ MY_ASSERT(lambda<=1 && lambda>0);
+ } else { // add f*(oov|...) to lambda
+ *ng.wordp(1)=dict->oovcode();
+ if (get(ng,size,size) && (!prunesingletons() || ng.freq>1 || size<3)){
+ lambda+=(double)ng.freq/(double)(history.freq - cv + history.succ);
+ }
+ }
+ } else {
+ fstar=0;
+ lambda=1;
+ }
+ } else {
+ fstar=unigr(ng);
+ lambda=0;
+ }
+ VERBOSE(3,"linearwb::discount(ngram ng_,int size,double& fstar,double& lambda,int cv) ng_:|" << ng_ << "| returning fstar:" << fstar << " lambda:" << lambda << std::endl);
+ return 1;
+ }
+
+ // Build a stupid-backoff LM of the given depth from the n-gram table in
+ // ngtfile; prunefreq is the history pruning threshold.
+ linearstb::linearstb(char* ngtfile,int depth,int prunefreq,TABLETYPE tt):
+ mdiadaptlm(ngtfile,depth,tt)
+ {
+ prunethresh=prunefreq;
+ cerr << "PruneThresh: " << prunethresh << "\n";
+ };
+
+ // Estimate the model: unigram statistics first, then successor
+ // statistics. Returns 1.
+ int linearstb::train()
+ {
+ trainunigr();
+
+ gensuccstat();
+
+ return 1;
+ }
+
+
+ // Stupid-backoff discounting: fstar is the plain relative frequency and
+ // lambda is the constant back-off weight 0.4, raised to 1.0 in some
+ // pruning corner cases. cv counts are removed for cross-validation.
+ // Returns 1; results are passed back through fstar and lambda.
+ int linearstb::discount(ngram ng_,int size,double& fstar,double& lambda,int cv)
+ {
+ VERBOSE(3,"linearstb::discount(ngram ng_,int size,double& fstar,double& lambda,int cv) ng_:|" << ng_ << "| size:" << size << " cv:" << cv<< std::endl);
+ ngram ng(dict);
+ ng.trans(ng_);
+
+ lambda = 0.4; // constant back-off weight of stupid backoff
+
+ if (size > 1) {
+ ngram history=ng;
+
+ if (ng.ckhisto(size) && get(history,size,size-1) && (history.freq>cv) &&
+ ((size < 3) || ((history.freq-cv) > prunethresh))) {
+ // apply history pruning on trigrams only
+
+ if (get(ng,size,size) && (!prunesingletons() || ng.freq>1 || size<3)) {
+ // apply frequency pruning on trigrams only
+
+ cv=(ng.freq<cv)?ng.freq:cv; //hence, ng.freq>=cv
+
+ if (ng.freq>cv) {
+ fstar=(double)(ng.freq-cv)/(double)(history.freq - cv);
+ if (size>=3 && prunesingletons()){ // correction due to frequency pruning
+ // NOTE(review): this branch tests size>3 while the two parallel
+ // branches below test history.size>3 — confirm which is intended
+ if (history.freq<=1 && size>3){
+ lambda = 1.0;
+ }
+ }
+ } else { // ng.freq == cv
+ fstar=0.0;
+ if (size>=3 && prunesingletons()){ // correction due to frequency pruning
+ if (history.freq<=1 && history.size>3){
+ lambda = 1.0;
+ }
+ }
+ }
+ } else {
+ fstar=0.0;
+ if (size>=3 && prunesingletons()){ // correction due to frequency pruning
+ if (history.freq<=1 && history.size>3){
+ lambda = 1.0;
+ }
+ }
+ }
+
+ //cerr << "ngram :" << ng << "\n";
+ // if current word is OOV then back-off to unigrams!
+ if (*ng.wordp(1)==dict->oovcode()) {
+ fstar=0.0;
+ }
+ }
+ else {
+ fstar=0;
+ lambda=1;
+ }
+ } else {
+ fstar=unigr(ng);
+ lambda=0;
+ }
+
+ VERBOSE(3,"linearstb::discount(ngram ng_,int size,double& fstar,double& lambda,int cv) ng_:|" << ng_ << "| returning fstar:" << fstar << " lambda:" << lambda << std::endl);
+ return 1;
+ }
+
+ // Stupid backoff uses a constant back-off weight: flag the model as
+ // back-off and set every history's back-off weight to 1.0 at all levels
+ // below the LM order. Returns 1.
+ int linearstb::compute_backoff()
+ {
+ VERBOSE(3,"linearstb::compute_backoff() ... ");
+
+ this->backoff=1;
+
+ for (int size=1; size<lmsize(); size++) {
+
+ ngram hg(dict,size);
+
+ scan(hg,INIT,size);
+ while(scan(hg,CONT,size)) {
+ boff(hg.link,1.0); // store unit back-off weight for this history
+ }
+ }
+
+ VERBOSE(3,"linearstb::compute_backoff() COMPLETED\n");
+ return 1;
+ }
+
+}//namespace irstlm
diff --git a/src/linearlm.h b/src/linearlm.h
new file mode 100644
index 0000000..3079a0b
--- /dev/null
+++ b/src/linearlm.h
@@ -0,0 +1,55 @@
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+
+
+// Linear discounting interpolated LMs
+
+
+
+namespace irstlm {
+ // Witten and Bell linear discounting: interpolated LM whose discounting
+ // reserves mass proportional to the number of distinct successors.
+ class linearwb: public mdiadaptlm
+ {
+ int prunethresh; // history pruning threshold
+ int minfreqthresh;
+ public:
+ linearwb(char* ngtfile,int depth=0,int prunefreq=0,TABLETYPE tt=SHIFTBETA_B);
+ int train();
+ int discount(ngram ng,int size,double& fstar,double& lambda,int cv=0);
+ ~linearwb() {}
+ };
+
+ // Stupid-Backoff LM type: undiscounted relative frequencies with a
+ // constant back-off weight (see linearstb::discount).
+ class linearstb: public mdiadaptlm
+ {
+ int prunethresh; // history pruning threshold
+ int minfreqthresh;
+ public:
+ linearstb(char* ngtfile,int depth=0,int prunefreq=0,TABLETYPE tt=SHIFTBETA_B);
+ int train();
+ int discount(ngram ng,int size,double& fstar,double& lambda,int cv=0);
+ ~linearstb() {}
+ int compute_backoff();
+ };
+
+ //Good Turing linear discounting
+ //no more supported
+
+}//namespace irstlm
\ No newline at end of file
diff --git a/src/lmContainer.cpp b/src/lmContainer.cpp
new file mode 100644
index 0000000..4d7a756
--- /dev/null
+++ b/src/lmContainer.cpp
@@ -0,0 +1,167 @@
+// $Id: lmContainer.cpp 3686 2010-10-15 11:55:32Z bertoldi $
+
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+#include <stdio.h>
+#include <cstdlib>
+#include <stdlib.h>
+#include <iostream>
+#include <stdexcept>
+#include <string>
+#include <sstream>
+#include "util.h"
+#include "lmContainer.h"
+#include "lmtable.h"
+#include "lmmacro.h"
+#include "lmclass.h"
+#include "lmInterpolation.h"
+
+using namespace std;
+
+namespace irstlm {
+
+#ifdef PS_CACHE_ENABLE
+#if PS_CACHE_ENABLE==0
+#undef PS_CACHE_ENABLE
+#endif
+#endif
+
+#ifdef LMT_CACHE_ENABLE
+#if LMT_CACHE_ENABLE==0
+#undef LMT_CACHE_ENABLE
+#endif
+#endif
+
+#if PS_CACHE_ENABLE
+bool lmContainer::ps_cache_enabled=true;
+#else
+bool lmContainer::ps_cache_enabled=false;
+#endif
+
+#if LMT_CACHE_ENABLE
+bool lmContainer::lmt_cache_enabled=true;
+#else
+bool lmContainer::lmt_cache_enabled=false;
+#endif
+
+// Print an error message to stderr, then abort by throwing std::runtime_error.
+inline void error(const char* message)
+{
+ std::cerr << message << "\n";
+ throw std::runtime_error(message);
+}
+
+// Default-initialize a container: unknown LM type, no levels loaded yet,
+// and effectively no restriction on the number of loadable levels.
+lmContainer::lmContainer()
+{
+ requiredMaxlev=1000;
+ lmtype=_IRSTLM_LMUNKNOWN;
+ maxlev=0;
+}
+
+// Open filename and inspect its first token to determine the LM type
+// (interpolation, macro, class, or plain table). Unknown headers default
+// to _IRSTLM_LMTABLE (e.g. ARPA or binary table files).
+int lmContainer::getLanguageModelType(std::string filename)
+{
+ fstream inp(filename.c_str(),ios::in|ios::binary);
+
+ if (!inp.good()) {
+ std::stringstream ss_msg;
+ ss_msg << "Failed to open " << filename;
+ exit_error(IRSTLM_ERROR_IO, ss_msg.str());
+ }
+ //give a look at the header to get informed about the language model type
+ std::string header;
+ inp >> header;
+ inp.close();
+
+ VERBOSE(1,"LM header:|" << header << "|" << std::endl);
+
+ int type=_IRSTLM_LMUNKNOWN;
+ VERBOSE(1,"type: " << type << std::endl);
+ // BUGFIX: also accept the spelling "lminterpolation" (single 'm'), which
+ // is the lowercase header accepted by interpolate-lm; the old check only
+ // matched the misspelled "lmminterpolation" and misdetected such files
+ // as plain tables. The old spelling is still accepted for compatibility.
+ if (header == "lminterpolation" || header == "lmminterpolation" || header == "LMINTERPOLATION") {
+ type = _IRSTLM_LMINTERPOLATION;
+ } else if (header == "lmmacro" || header == "LMMACRO") {
+ type = _IRSTLM_LMMACRO;
+ } else if (header == "lmclass" || header == "LMCLASS") {
+ type = _IRSTLM_LMCLASS;
+ } else {
+ type = _IRSTLM_LMTABLE;
+ }
+ VERBOSE(1,"type: " << type << std::endl);
+
+ return type;
+};
+
+// Factory overload: detect the LM type from the file header, then delegate
+// to the type-based factory.
+lmContainer* lmContainer::CreateLanguageModel(const std::string infile, float nlf, float dlf)
+{
+ const int detected = getLanguageModelType(infile);
+ std::cerr << "Language Model Type of " << infile << " is " << detected << std::endl;
+ return CreateLanguageModel(detected, nlf, dlf);
+}
+
+// Factory: instantiate the lmContainer subclass matching type, passing the
+// ngram-cache (nlf) and dictionary (dlf) load factors; aborts on an unknown
+// type. The created object is tagged with its own type before returning.
+lmContainer* lmContainer::CreateLanguageModel(int type, float nlf, float dlf)
+{
+
+ std::cerr << "Language Model Type is " << type << std::endl;
+
+ lmContainer* lm=NULL;
+
+ switch (type) {
+
+ case _IRSTLM_LMTABLE:
+ lm = new lmtable(nlf, dlf);
+ break;
+
+ case _IRSTLM_LMMACRO:
+ lm = new lmmacro(nlf, dlf);
+ break;
+
+ case _IRSTLM_LMCLASS:
+ lm = new lmclass(nlf, dlf);
+ break;
+
+ case _IRSTLM_LMINTERPOLATION:
+ lm = new lmInterpolation(nlf, dlf);
+ break;
+
+ default:
+ exit_error(IRSTLM_ERROR_DATA, "This language model type is unknown!");
+ }
+
+ lm->setLanguageModelType(type);
+ return lm;
+}
+
+// Build a sub-LM restricted by the filter file sfilter; only supported when
+// this container is a plain lmtable. On success sublmC points to a newly
+// created container configured like this one (inversion flag, max levels).
+// Returns false for any other LM type.
+bool lmContainer::filter(const string sfilter, lmContainer*& sublmC, const string skeepunigrams)
+{
+ if (lmtype == _IRSTLM_LMTABLE) {
+ sublmC = lmContainer::CreateLanguageModel(lmtype,((lmtable*) this)->GetNgramcacheLoadFactor(),((lmtable*) this)->GetDictionaryLoadFactor());
+
+ //let know that table has inverted n-grams
+ sublmC->is_inverted(is_inverted());
+ sublmC->setMaxLoadedLevel(getMaxLoadedLevel());
+ sublmC->maxlevel(maxlevel());
+
+ bool res=((lmtable*) this)->filter(sfilter, (lmtable*) sublmC, skeepunigrams);
+
+ return res;
+ }
+ return false;
+};
+
+}//namespace irstlm
diff --git a/src/lmContainer.h b/src/lmContainer.h
new file mode 100644
index 0000000..03c8b37
--- /dev/null
+++ b/src/lmContainer.h
@@ -0,0 +1,198 @@
+// $Id: lmContainer.h 3686 2010-10-15 11:55:32Z bertoldi $
+
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+
+#ifndef MF_LMCONTAINER_H
+#define MF_LMCONTAINER_H
+
+#define _IRSTLM_LMUNKNOWN 0
+#define _IRSTLM_LMTABLE 1
+#define _IRSTLM_LMMACRO 2
+#define _IRSTLM_LMCLASS 3
+#define _IRSTLM_LMINTERPOLATION 4
+
+
+#include <stdio.h>
+#include <cstdlib>
+#include <stdlib.h>
+#include "util.h"
+#include "n_gram.h"
+#include "dictionary.h"
+
+typedef enum {BINARY,TEXT,YRANIB,NONE} OUTFILE_TYPE;
+
+namespace irstlm {
+// Abstract base/factory for all LM implementations (lmtable, lmmacro,
+// lmclass, lmInterpolation). Most virtuals are defined here as no-ops
+// returning neutral values, so subclasses override only what they support.
+class lmContainer
+{
+ static const bool debug=true;
+ static bool ps_cache_enabled;
+ static bool lmt_cache_enabled;
+
+protected:
+ int lmtype; //auto reference to its own type
+ int maxlev; //maximum order of sub LMs;
+ int requiredMaxlev; //max loaded level, i.e. load up to requiredMaxlev levels
+
+public:
+
+ lmContainer();
+ virtual ~lmContainer() {};
+
+
+ // Load the model from `filename`; `mmap` selects memory-mapped access.
+ // Base implementation is a no-op; every concrete LM overrides this.
+ virtual void load(const std::string &filename, int mmap=0) {
+ UNUSED(filename);
+ UNUSED(mmap);
+ };
+
+ virtual void savetxt(const char *filename) {
+ UNUSED(filename);
+ };
+ virtual void savebin(const char *filename) {
+ UNUSED(filename);
+ };
+
+ // OOV penalty accessors; base class keeps no penalty and returns 0.0.
+ virtual double getlogOOVpenalty() const {
+ return 0.0;
+ };
+ virtual double setlogOOVpenalty(int dub) {
+ UNUSED(dub);
+ return 0.0;
+ };
+ virtual double setlogOOVpenalty(double oovp) {
+ UNUSED(oovp);
+ return 0.0;
+ };
+
+ inline virtual dictionary* getDict() const {
+ return NULL;
+ };
+ inline virtual void maxlevel(int lev) {
+ maxlev = lev;
+ };
+ inline virtual int maxlevel() const {
+ return maxlev;
+ };
+ inline virtual void stat(int lev=0) {
+ UNUSED(lev);
+ };
+
+ inline virtual void setMaxLoadedLevel(int lev) {
+ requiredMaxlev=lev;
+ };
+ inline virtual int getMaxLoadedLevel() {
+ return requiredMaxlev;
+ };
+
+ virtual bool is_inverted(const bool flag) {
+ UNUSED(flag);
+ return false;
+ };
+ virtual bool is_inverted() {
+ return false;
+ };
+ // Conditional log-probability of the last word of `ng` given its history.
+ // Optional out-params (bow, bol, maxsuffptr, statesize, extendible) are
+ // filled by subclasses; the base class ignores them and returns 0.0.
+ virtual double clprob(ngram ng, double* bow=NULL, int* bol=NULL, char** maxsuffptr=NULL, unsigned int* statesize=NULL,bool* extendible=NULL) {
+ UNUSED(ng);
+ UNUSED(bow);
+ UNUSED(bol);
+ UNUSED(maxsuffptr);
+ UNUSED(statesize);
+ UNUSED(extendible);
+ return 0.0;
+ };
+ virtual double clprob(int* ng, int ngsize, double* bow=NULL, int* bol=NULL, char** maxsuffptr=NULL, unsigned int* statesize=NULL,bool* extendible=NULL) {
+ UNUSED(ng);
+ UNUSED(ngsize);
+ UNUSED(bow);
+ UNUSED(bol);
+ UNUSED(maxsuffptr);
+ UNUSED(statesize);
+ UNUSED(extendible);
+ return 0.0;
+ };
+
+
+ virtual const char *cmaxsuffptr(ngram ng, unsigned int* statesize=NULL)
+ {
+ UNUSED(ng);
+ UNUSED(statesize);
+ return NULL;
+ }
+
+ virtual const char *cmaxsuffptr(int* ng, int ngsize, unsigned int* statesize=NULL)
+ {
+ UNUSED(ng);
+ UNUSED(ngsize);
+ UNUSED(statesize);
+ return NULL;
+ }
+
+ virtual void used_caches() {};
+ virtual void init_caches(int uptolev) {
+ UNUSED(uptolev);
+ };
+ virtual void check_caches_levels() {};
+ virtual void reset_caches() {};
+
+ virtual void reset_mmap() {};
+
+ void inline setLanguageModelType(int type) {
+ lmtype=type;
+ };
+ int getLanguageModelType() const {
+ return lmtype;
+ };
+ // Static variant: detect the type from a model file's header.
+ static int getLanguageModelType(std::string filename);
+
+ inline virtual void dictionary_incflag(const bool flag) {
+ UNUSED(flag);
+ };
+
+ virtual bool filter(const string sfilter, lmContainer*& sublmt, const string skeepunigrams);
+
+ // Factory methods (defined in lmContainer.cpp).
+ static lmContainer* CreateLanguageModel(const std::string infile, float nlf=0.0, float dlf=0.0);
+ static lmContainer* CreateLanguageModel(int type, float nlf=0.0, float dlf=0.0);
+
+ inline virtual bool is_OOV(int code) {
+ UNUSED(code);
+ return false;
+ };
+
+
+ inline static bool is_lmt_cache_enabled(){
+ VERBOSE(3,"inline static bool is_lmt_cache_enabled() " << lmt_cache_enabled << std::endl);
+ return lmt_cache_enabled;
+ }
+
+ inline static bool is_ps_cache_enabled(){
+ VERBOSE(3,"inline static bool is_ps_cache_enabled() " << ps_cache_enabled << std::endl);
+ return ps_cache_enabled;
+ }
+
+ inline static bool is_cache_enabled(){
+ return is_lmt_cache_enabled() && is_ps_cache_enabled();
+ }
+
+};
+
+}//namespace irstlm
+
+#endif
+
diff --git a/src/lmInterpolation.cpp b/src/lmInterpolation.cpp
new file mode 100644
index 0000000..de66a96
--- /dev/null
+++ b/src/lmInterpolation.cpp
@@ -0,0 +1,243 @@
+// $Id: lmInterpolation.cpp 3686 2010-10-15 11:55:32Z bertoldi $
+
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+#include <stdio.h>
+#include <cstdlib>
+#include <stdlib.h>
+#include <iostream>
+#include <stdexcept>
+#include <string>
+#include "lmContainer.h"
+#include "lmInterpolation.h"
+#include "util.h"
+
+using namespace std;
+
+// Report a fatal configuration error: print to stderr, then throw so the
+// caller (or main) can terminate cleanly.
+inline void error(const char* message)
+{
+ std::cerr << message << "\n";
+ throw std::runtime_error(message);
+}
+
+namespace irstlm {
+// Construct an empty interpolation container; sub-LMs are added by load().
+// nlf/dlf are cached and forwarded to each sub-LM when it is loaded.
+lmInterpolation::lmInterpolation(float nlf, float dlf)
+{
+ ngramcache_load_factor = nlf;
+ dictionary_load_factor = dlf;
+
+ order=0;      // 0 = "unset"; load() will default it to the max sub-LM order
+ memmap=0;
+ isInverted=false;
+}
+
+// Parse the interpolation configuration file:
+//   LMINTERPOLATION <number_of_models>
+//   <weight_i> <file_i> [inverted]     (one line per sub-LM)
+// Loads every sub-LM, merges all sub-dictionaries into this container's
+// dictionary, and sets the container order to the maximum sub-LM order.
+void lmInterpolation::load(const std::string &filename,int mmap)
+{
+ VERBOSE(2,"lmInterpolation::load(const std::string &filename,int memmap)" << std::endl);
+ VERBOSE(2," filename:|" << filename << "|" << std::endl);
+
+
+ dictionary_upperbound=1000000;
+ int memmap=mmap;
+
+
+ dict=new dictionary((char *)NULL,1000000,dictionary_load_factor);
+
+ //get info from the configuration file
+ fstream inp(filename.c_str(),ios::in|ios::binary);
+
+ char line[MAX_LINE];
+ const char* words[LMINTERPOLATION_MAX_TOKEN];
+ int tokenN;
+ inp.getline(line,MAX_LINE,'\n');
+ tokenN = parseWords(line,words,LMINTERPOLATION_MAX_TOKEN);
+
+ if (tokenN != 2 || ((strcmp(words[0],"LMINTERPOLATION") != 0) && (strcmp(words[0],"lminterpolation")!=0)))
+ error((char*)"ERROR: wrong header format of configuration file\ncorrect format: LMINTERPOLATION number_of_models\nweight_of_LM_1 filename_of_LM_1\nweight_of_LM_2 filename_of_LM_2");
+
+ m_number_lm = atoi(words[1]);
+
+ m_weight.resize(m_number_lm);
+ m_file.resize(m_number_lm);
+ m_isinverted.resize(m_number_lm);
+ m_lm.resize(m_number_lm);
+
+ VERBOSE(2,"lmInterpolation::load(const std::string &filename,int mmap) m_number_lm:"<< m_number_lm << std::endl;);
+
+ dict->incflag(1);
+ for (int i=0; i<m_number_lm; i++) {
+ // NOTE(review): `line` is MAX_LINE chars but the limit passed here is
+ // BUFSIZ; if BUFSIZ > MAX_LINE this can overflow the buffer — confirm,
+ // the header read above uses MAX_LINE.
+ inp.getline(line,BUFSIZ,'\n');
+ tokenN = parseWords(line,words,3);
+
+ if(tokenN < 2 || tokenN >3) {
+ error((char*)"ERROR: wrong header format of configuration file\ncorrect format: LMINTERPOLATION number_of_models\nweight_of_LM_1 filename_of_LM_1\nweight_of_LM_2 filename_of_LM_2");
+ }
+
+ //check whether the (textual) LM has to be loaded as inverted
+ m_isinverted[i] = false;
+ if(tokenN == 3) {
+ if (strcmp(words[2],"inverted") == 0)
+ m_isinverted[i] = true;
+ }
+ VERBOSE(2,"i:" << i << " m_isinverted[i]:" << m_isinverted[i] << endl);
+
+ m_weight[i] = (float) atof(words[0]);
+ m_file[i] = words[1];
+ VERBOSE(2,"lmInterpolation::load(const std::string &filename,int mmap) m_file:"<< words[1] << std::endl;);
+
+ m_lm[i] = load_lm(i,memmap,ngramcache_load_factor,dictionary_load_factor);
+ //set the actual value for inverted flag, which is known only after loading the LM
+ m_isinverted[i] = m_lm[i]->is_inverted();
+
+ // merge the sub-LM dictionary into the container dictionary
+ dictionary *_dict=m_lm[i]->getDict();
+ for (int j=0; j<_dict->size(); j++) {
+ dict->encode(_dict->decode(j));
+ }
+ }
+ getDict()->genoovcode();
+
+ // NOTE(review): incflag(1) keeps the dictionary open for insertion after
+ // loading is finished; incflag(0) (freeze) looks intended here — confirm.
+ getDict()->incflag(1);
+ inp.close();
+
+ int maxorder = 0;
+ for (int i=0; i<m_number_lm; i++) {
+ maxorder = (maxorder > m_lm[i]->maxlevel())?maxorder:m_lm[i]->maxlevel();
+ }
+
+ if (order == 0) {
+ order = maxorder;
+ std::cerr << "order is not set; reset to the maximum order of LMs: " << order << std::endl;
+ } else if (order > maxorder) {
+ order = maxorder;
+ std::cerr << "order is too high; reset to the maximum order of LMs: " << order << std::endl;
+ }
+ maxlev=order;
+}
+
+// Create and load sub-LM number `i` (file m_file[i]) via the lmContainer
+// factory, propagating the inverted flag, the max-loaded-level limit and the
+// load factors; caches are initialized up to the sub-LM's own order.
+lmContainer* lmInterpolation::load_lm(int i,int memmap, float nlf, float dlf)
+{
+ //checking the language model type
+ lmContainer* lmt=lmContainer::CreateLanguageModel(m_file[i],nlf,dlf);
+
+ //let know that table has inverted n-grams
+ lmt->is_inverted(m_isinverted[i]); //set inverted flag for each LM
+
+ lmt->setMaxLoadedLevel(requiredMaxlev);
+
+ lmt->load(m_file[i], memmap);
+
+ lmt->init_caches(lmt->maxlevel());
+ return lmt;
+}
+
+
+// Interpolated conditional log-probability of `ng`:
+// each sub-LM's probability is converted back to the linear domain
+// (pow(10,logpr)), combined as a weighted sum, and the result is returned
+// as log10. Out-params aggregate over sub-LMs as documented in the comment
+// block inside the loop.
+double lmInterpolation::clprob(ngram ng, double* bow,int* bol,char** maxsuffptr,unsigned int* statesize,bool* extendible)
+{
+
+ double pr=0.0;
+ double _logpr;
+
+ char* _maxsuffptr=NULL,*actualmaxsuffptr=NULL;
+ unsigned int _statesize=0,actualstatesize=0;
+ int _bol=0,actualbol=MAX_NGRAM;
+ double _bow=0.0,actualbow=0.0;
+ bool _extendible=false;
+ bool actualextendible=false;
+
+ for (size_t i=0; i<m_lm.size(); i++) {
+
+ // re-encode the ngram with the sub-LM's own dictionary
+ ngram _ng(m_lm[i]->getDict());
+ _ng.trans(ng);
+ _logpr=m_lm[i]->clprob(_ng,&_bow,&_bol,&_maxsuffptr,&_statesize,&_extendible);
+
+ /*
+ cerr.precision(10);
+ std::cerr << " LM " << i << " weight:" << m_weight[i] << std::endl;
+ std::cerr << " LM " << i << " log10 logpr:" << _logpr<< std::endl;
+ std::cerr << " LM " << i << " pr:" << pow(10.0,_logpr) << std::endl;
+ std::cerr << " _statesize:" << _statesize << std::endl;
+ std::cerr << " _bow:" << _bow << std::endl;
+ std::cerr << " _bol:" << _bol << std::endl;
+ */
+
+ //TO CHECK the following claims
+ //What is the statesize of a LM interpolation? The largest _statesize among the submodels
+ //What is the maxsuffptr of a LM interpolation? The _maxsuffptr of the submodel with the largest _statesize
+ //What is the bol of a LM interpolation? The smallest _bol among the submodels
+ //What is the bow of a LM interpolation? The weighted sum of the bow of the submodels
+ //What is the prob of a LM interpolation? The weighted sum of the prob of the submodels
+ //What is the extendible flag of a LM interpolation? true if the extendible flag is one for any LM
+
+ pr+=m_weight[i]*pow(10.0,_logpr);
+ actualbow+=m_weight[i]*pow(10.0,_bow);
+
+ if(_statesize > actualstatesize || i == 0) {
+ actualmaxsuffptr = _maxsuffptr;
+ actualstatesize = _statesize;
+ }
+ if (_bol < actualbol) {
+ actualbol=_bol; //backoff limit of LM[i]
+ }
+ if (_extendible) {
+ actualextendible=true; //set extendible flag to true if the ngram is extendible for any LM
+ }
+ }
+ if (bol) *bol=actualbol;
+ // NOTE(review): actualbow is accumulated in base 10 (pow(10,_bow)) but
+ // converted back with natural log(); the return value below uses
+ // log(pr)/M_LN10 (i.e. log10) — confirm whether *bow should also be log10.
+ if (bow) *bow=log(actualbow);
+ if (maxsuffptr) *maxsuffptr=actualmaxsuffptr;
+ if (statesize) *statesize=actualstatesize;
+ if (extendible) {
+ *extendible=actualextendible;
+ // delete _extendible;
+ }
+
+ /*
+ if (statesize) std::cerr << " statesize:" << *statesize << std::endl;
+ if (bow) std::cerr << " bow:" << *bow << std::endl;
+ if (bol) std::cerr << " bol:" << *bol << std::endl;
+ */
+ // log(pr)/M_LN10 == log10(pr): return the interpolated prob in log10
+ return log(pr)/M_LN10;
+}
+
+// Convenience overload: wrap an array of `sz` word codes (already encoded
+// with this container's dictionary) into an ngram and delegate to the
+// ngram-based clprob above.
+double lmInterpolation::clprob(int* codes, int sz, double* bow,int* bol,char** maxsuffptr,unsigned int* statesize,bool* extendible)
+{
+
+ //create the actual ngram
+ ngram ong(dict);
+ ong.pushc(codes,sz);
+ MY_ASSERT (ong.size == sz);
+
+ return clprob(ong, bow, bol, maxsuffptr, statesize, extendible);
+}
+
+// Set the OOV penalty from the dictionary upper bound `dub`: each sub-LM
+// computes its own penalty, and the container stores the log of their
+// weighted (linear-domain) sum.
+double lmInterpolation::setlogOOVpenalty(int dub)
+{
+ MY_ASSERT(dub > dict->size());
+ double _logpr;
+ double OOVpenalty=0.0;
+ for (int i=0; i<m_number_lm; i++) {
+ m_lm[i]->setlogOOVpenalty(dub); //set OOV Penalty for each LM
+ _logpr=m_lm[i]->getlogOOVpenalty();
+ // NOTE(review): exp()/log() here are natural-base, while clprob combines
+ // sub-LM scores with pow(10,.)/log10 — confirm the bases are consistent.
+ OOVpenalty+=m_weight[i]*exp(_logpr);
+ }
+ logOOVpenalty=log(OOVpenalty);
+ return logOOVpenalty;
+}
+}//namespace irstlm
diff --git a/src/lmInterpolation.h b/src/lmInterpolation.h
new file mode 100644
index 0000000..f5d2627
--- /dev/null
+++ b/src/lmInterpolation.h
@@ -0,0 +1,131 @@
+// $Id: lmInterpolation.h 3686 2010-10-15 11:55:32Z bertoldi $
+
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+
+#ifndef MF_LMINTERPOLATION_H
+#define MF_LMINTERPOLATION_H
+
+#include <stdio.h>
+#include <cstdlib>
+#include <stdlib.h>
+#include <string>
+#include <math.h>
+#include <vector>
+#include "util.h"
+#include "dictionary.h"
+#include "n_gram.h"
+#include "lmContainer.h"
+
+
+namespace irstlm {
+/*
+interpolation of several sub LMs
+*/
+
+#define LMINTERPOLATION_MAX_TOKEN 3
+
+// Linear interpolation of several sub-LMs, each with its own weight and
+// file; the container keeps a merged dictionary covering all sub-LMs.
+class lmInterpolation: public lmContainer
+{
+ static const bool debug=true;
+ int m_number_lm;     // number of interpolated sub-LMs
+ int order;
+ int dictionary_upperbound; //set by user
+ double logOOVpenalty; //penalty for OOV words (default 0)
+ bool isInverted;
+ int memmap; //level from which n-grams are accessed via mmap
+
+ std::vector<double> m_weight;
+ std::vector<std::string> m_file;
+ std::vector<bool> m_isinverted;
+ std::vector<lmContainer*> m_lm;
+
+ // NOTE(review): this shadows lmContainer::maxlev — confirm intentional.
+ int maxlev; //maximum order of sub LMs;
+
+ float ngramcache_load_factor;
+ float dictionary_load_factor;
+
+ dictionary *dict; // dictionary for all interpolated LMs
+
+public:
+
+ lmInterpolation(float nlf=0.0, float dlfi=0.0);
+ virtual ~lmInterpolation() {};
+
+ void load(const std::string &filename,int mmap=0);
+ lmContainer* load_lm(int i, int memmap, float nlf, float dlf);
+
+ virtual double clprob(ngram ng, double* bow=NULL,int* bol=NULL,char** maxsuffptr=NULL,unsigned int* statesize=NULL,bool* extendible=NULL);
+ virtual double clprob(int* ng, int ngsize, double* bow=NULL,int* bol=NULL,char** maxsuffptr=NULL,unsigned int* statesize=NULL,bool* extendible=NULL);
+
+ int maxlevel() const {
+ return maxlev;
+ };
+
+ virtual inline void setDict(dictionary* d) {
+ if (dict) delete dict;
+ dict=d;
+ };
+
+ virtual inline dictionary* getDict() const {
+ return dict;
+ };
+
+ //set penalty for OOV words
+ virtual inline double getlogOOVpenalty() const {
+ return logOOVpenalty;
+ }
+
+ virtual double setlogOOVpenalty(int dub);
+
+ double inline setlogOOVpenalty(double oovp) {
+ return logOOVpenalty=oovp;
+ }
+
+//set the inverted flag (used to set the inverted flag of each subLM, when loading)
+ inline bool is_inverted(const bool flag) {
+ return isInverted = flag;
+ }
+
+//for an interpolation LM this variable does not make sense
+//for compatibility, we return true if all subLM return true
+ inline bool is_inverted() {
+ for (int i=0; i<m_number_lm; i++) {
+ if (m_isinverted[i] == false) return false;
+ }
+ return true;
+ }
+
+ inline virtual void dictionary_incflag(const bool flag) {
+ dict->incflag(flag);
+ };
+
+ inline virtual bool is_OOV(int code) { //returns true if the word is OOV for each subLM
+ for (int i=0; i<m_number_lm; i++) {
+ int _code=m_lm[i]->getDict()->encode(getDict()->decode(code));
+ if (m_lm[i]->is_OOV(_code) == false) return false;
+ }
+ return true;
+ }
+};
+}//namespace irstlm
+
+#endif
+
diff --git a/src/lmclass.cpp b/src/lmclass.cpp
new file mode 100644
index 0000000..75626b3
--- /dev/null
+++ b/src/lmclass.cpp
@@ -0,0 +1,236 @@
+// $Id: lmclass.cpp 3631 2010-10-07 12:04:12Z bertoldi $
+
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+#include <stdio.h>
+#include <stdlib.h>
+#include <fcntl.h>
+#include <iostream>
+#include <fstream>
+#include <stdexcept>
+#include "math.h"
+#include "mempool.h"
+#include "htable.h"
+#include "ngramcache.h"
+#include "dictionary.h"
+#include "n_gram.h"
+#include "lmclass.h"
+#include "util.h"
+
+using namespace std;
+
+// local utilities: start
+
+int parseWords(char *sentence, const char **words, int max);
+
+// Report a fatal configuration error: print to stderr, then throw so the
+// caller (or main) can terminate cleanly.
+inline void error(const char* message)
+{
+ cerr << message << "\n";
+ throw runtime_error(message);
+}
+
+// local utilities: end
+
+namespace irstlm {
+
+// Construct a class-based LM: pre-allocate the word->score map (MapScore,
+// grown later by checkMap) and the word->class dictionary.
+lmclass::lmclass(float nlf, float dlfi):lmtable(nlf,dlfi)
+{
+ MaxMapSize=1000000;
+ MapScore= (double *)malloc(MaxMapSize*sizeof(double));// //array of probabilities
+ memset(MapScore,0,MaxMapSize*sizeof(double));
+ MapScoreN=0;
+ dict = new dictionary((char *)NULL,MaxMapSize); //word to cluster dictionary
+};
+
+// Release the map score array and the word->class dictionary allocated in
+// the constructor.
+lmclass::~lmclass()
+{
+ free (MapScore);
+ delete dict;
+}
+
+// Parse the lmclass configuration file:
+//   LMCLASS <LM_order>
+//   <filename_of_LM>
+//   <filename_of_map>
+// then load the (possibly binary) class LM via lmtable::load and the
+// word-to-class map via loadMap. A missing map is fatal.
+void lmclass::load(const std::string &filename,int memmap)
+{
+ VERBOSE(2,"lmclass::load(const std::string &filename,int memmap)" << std::endl);
+
+ //get info from the configuration file
+ fstream inp(filename.c_str(),ios::in|ios::binary);
+
+ char line[MAX_LINE];
+ const char* words[LMCLASS_MAX_TOKEN];
+ int tokenN;
+ inp.getline(line,MAX_LINE,'\n');
+ tokenN = parseWords(line,words,LMCLASS_MAX_TOKEN);
+
+ if (tokenN != 2 || ((strcmp(words[0],"LMCLASS") != 0) && (strcmp(words[0],"lmclass")!=0)))
+ error((char*)"ERROR: wrong header format of configuration file\ncorrect format: LMCLASS LM_order\nfilename_of_LM\nfilename_of_map");
+
+ maxlev = atoi(words[1]);
+ std::string lmfilename;
+ if (inp.getline(line,MAX_LINE,'\n')) {
+ tokenN = parseWords(line,words,LMCLASS_MAX_TOKEN);
+ lmfilename = words[0];
+ } else {
+ error((char*)"ERROR: wrong header format of configuration file\ncorrect format: LMCLASS LM_order\nfilename_of_LM\nfilename_of_map");
+ }
+
+ std::string W2Cdict = "";
+ if (inp.getline(line,MAX_LINE,'\n')) {
+ tokenN = parseWords(line,words,LMCLASS_MAX_TOKEN);
+ W2Cdict = words[0];
+ } else {
+ error((char*)"ERROR: wrong header format of configuration file\ncorrect format: LMCLASS LM_order\nfilename_of_LM\nfilename_of_map");
+ }
+ inp.close();
+
+ std::cerr << "lmfilename:" << lmfilename << std::endl;
+ if (W2Cdict != "") {
+ std::cerr << "mapfilename:" << W2Cdict << std::endl;
+ } else {
+ error((char*)"ERROR: you must specify a map!");
+ }
+
+
+ // Load the (possibly binary) LM
+ inputfilestream inpLM(lmfilename.c_str());
+ if (!inpLM.good()) {
+ std::cerr << "Failed to open " << lmfilename << "!" << std::endl;
+ exit(1);
+ }
+ lmtable::load(inpLM,lmfilename.c_str(),NULL,memmap);
+
+ inputfilestream inW2C(W2Cdict);
+ if (!inW2C.good()) {
+ std::cerr << "Failed to open " << W2Cdict << "!" << std::endl;
+ exit(1);
+ }
+ loadMap(inW2C);
+ getDict()->genoovcode();
+
+ VERBOSE(2,"OOV code of lmclass is " << getDict()->oovcode() << " mapped into " << getMap(getDict()->oovcode())<< "\n");
+ getDict()->incflag(1);
+}
+
+// Read the word-to-class map, one "<word> <class> [prob]" triple per line.
+// A missing probability defaults to 1.0 (log10 = 0.0). BoS/EoS/OOV pairs
+// are seeded first so both dictionaries agree on the special symbols.
+void lmclass::loadMap(istream& inW2C)
+{
+
+ double lprob=0.0;
+ int howmany=0;
+
+ const char* words[1 + LMTMAXLEV + 1 + 1];
+
+ //open input stream and prepare an input string
+ char line[MAX_LINE];
+
+ dict->incflag(1); //can add to the map dictionary
+
+ cerr<<"loadW2Cdict()...\n";
+ //save freq of EOS and BOS
+
+ loadMapElement(dict->BoS(),lmtable::dict->BoS(),0.0);
+ loadMapElement(dict->EoS(),lmtable::dict->EoS(),0.0);
+
+ //should i add <unk> to the dict or just let the trans_freq handle <unk>
+ loadMapElement(dict->OOV(),lmtable::dict->OOV(),0.0);
+
+ while (inW2C.getline(line,MAX_LINE)) {
+ if (strlen(line)==MAX_LINE-1) {
+ cerr << "lmtable::loadW2Cdict: input line exceed MAXLINE ("
+ << MAX_LINE << ") chars " << line << "\n";
+ exit(1);
+ }
+
+ howmany = parseWords(line, words, 4); //3
+
+ if(howmany == 3) {
+ // third field is a linear-domain probability; store it as log10
+ MY_ASSERT(sscanf(words[2], "%lf", &lprob));
+ lprob=(double)log10(lprob);
+ } else if(howmany==2) {
+
+ VERBOSE(3,"No score for the pair (" << words[0] << "," << words[1] << "); set to default 1.0\n");
+
+ lprob=0.0;
+ } else {
+ cerr << "parseline: not enough entries" << line << "\n";
+ exit(1);
+ }
+ loadMapElement(words[0],words[1],lprob);
+
+ //check if there are available positions in MapScore
+ checkMap();
+ }
+
+ VERBOSE(2,"There are " << MapScoreN << " entries in the map\n");
+
+ dict->incflag(0); //can NOT add to the dictionary of lmclass
+}
+
+// Grow the MapScore array (doubling) once the number of entries exceeds the
+// current capacity.
+// NOTE(review): this is called *after* loadMapElement has already written
+// MapScore[wcode]; an entry landing beyond MaxMapSize would be written out
+// of bounds before the growth happens — confirm whether the check should
+// precede the write. Also, reallocf is a BSD extension — confirm it is
+// provided on all supported platforms.
+void lmclass::checkMap()
+{
+ if (MapScoreN > MaxMapSize) {
+ MaxMapSize=2*MapScoreN;
+ MapScore = (double*) reallocf(MapScore, sizeof(double)*(MaxMapSize));
+ VERBOSE(2,"In lmclass::checkMap(...) MaxMapSize=" << MaxMapSize << " MapScoreN=" << MapScoreN << "\n");
+ }
+}
+
+// Register one map entry: encode word `in` in the word dictionary, store the
+// class code of `out` as that word's "frequency", and record score `sc` in
+// MapScore at the word's code.
+void lmclass::loadMapElement(const char* in, const char* out, double sc)
+{
+ //freq of word (in) encodes the ID of the class (out)
+ //save the probability associated with the pair (in,out)
+ int wcode=dict->encode(in);
+ dict->freq(wcode,lmtable::dict->encode(out));
+ MapScore[wcode]=sc;
+ VERBOSE(3,"In lmclass::loadMapElement(...) in=" << in << " wcode=" << wcode << " out=" << out << " ccode=" << lmtable::dict->encode(out) << " MapScoreN=" << MapScoreN << "\n");
+
+ if (wcode >= MapScoreN) MapScoreN++; //increment size of the array MapScore if the element is new
+}
+
+// Class LM probability: the emission score of the last word (MapScore) plus
+// the class-level ngram probability from the underlying lmtable, computed on
+// the class-mapped version of the input ngram.
+double lmclass::lprob(ngram ong,double* bow, int* bol, char** maxsuffptr,unsigned int* statesize,bool* extendible)
+{
+ double lpr=getMapScore(*ong.wordp(1));
+
+ VERBOSE(3,"In lmclass::lprob(...) Mapscore = " << lpr << "\n");
+
+ //convert ong to its clustered encoding
+ ngram mapped_ng(lmtable::getDict());
+ // mapped_ng.trans_freq(ong);
+ mapping(ong,mapped_ng);
+
+ lpr+=lmtable::clprob(mapped_ng,bow,bol,maxsuffptr,statesize, extendible);
+
+ VERBOSE(3,"In lmclass::lprob(...) global prob = " << lpr << "\n");
+ return lpr;
+}
+
+// Translate a word-level ngram into the corresponding class-level ngram by
+// pushing the class code of each word (oldest word first, via getMap).
+void lmclass::mapping(ngram &in, ngram &out)
+{
+ int insize = in.size;
+ VERBOSE(3,"In lmclass::mapping(ngram &in, ngram &out) in = " << in << "\n");
+
+ // map the input sequence (in) into the corresponding output sequence (out), by applying the provided map
+ for (int i=insize; i>0; i--) {
+ out.pushc(getMap(*in.wordp(i)));
+ }
+
+ VERBOSE(3,"In lmclass::mapping(ngram &in, ngram &out) out = " << out << "\n");
+ return;
+}
+}//namespace irstlm
+
diff --git a/src/lmclass.h b/src/lmclass.h
new file mode 100644
index 0000000..408291d
--- /dev/null
+++ b/src/lmclass.h
@@ -0,0 +1,104 @@
+// $Id: lmclass.h 3461 2010-08-27 10:17:34Z bertoldi $
+
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+
+
+#ifndef MF_LMCLASS_H
+#define MF_LMCLASS_H
+
+#ifndef WIN32
+#include <sys/types.h>
+#include <sys/mman.h>
+#endif
+
+#include "util.h"
+#include "ngramcache.h"
+#include "dictionary.h"
+#include "n_gram.h"
+#include "lmtable.h"
+
+#define LMCLASS_MAX_TOKEN 2
+
+namespace irstlm {
+// Class-based LM: words are mapped to classes (via a word->class dictionary
+// and per-word scores in MapScore), and the ngram probability is computed by
+// the underlying lmtable on the class sequence.
+class lmclass: public lmtable
+{
+ // NOTE(review): hides lmtable::dict — the word-level dictionary lives
+ // here, the class-level one in the base class; confirm intentional.
+ dictionary *dict; // dictionary (words - macro tags)
+ double *MapScore;  // per-word log10 emission scores, indexed by word code
+ int MapScoreN;     // number of valid entries in MapScore
+ int MaxMapSize;    // current capacity of MapScore
+
+protected:
+ void loadMap(std::istream& inp);
+ void loadMapElement(const char* in, const char* out, double sc);
+ void mapping(ngram &in, ngram &out);
+
+ inline double getMapScore(int wcode) {
+//the input word is un-known by the map, so I "transform" this word into the oov (of the words)
+ if (wcode >= MapScoreN) {
+ wcode = getDict()->oovcode();
+ }
+ return MapScore[wcode];
+ };
+
+ inline size_t getMap(int wcode) {
+//the input word is un-known by the map, so I "transform" this word into the oov (of the words)
+ if (wcode >= MapScoreN) {
+ wcode = getDict()->oovcode();
+ }
+ return dict->freq(wcode);
+ };
+
+ void checkMap();
+
+public:
+ lmclass(float nlf=0.0, float dlfi=0.0);
+
+ ~lmclass();
+
+ void load(const std::string &filename,int mmap=0);
+
+ double lprob(ngram ng, double* bow=NULL,int* bol=NULL,char** maxsuffptr=NULL,unsigned int* statesize=NULL,bool* extendible=NULL);
+ inline double clprob(ngram ng,double* bow=NULL,int* bol=NULL,char** maxsuffptr=NULL,unsigned int* statesize=NULL,bool* extendible=NULL) {
+ return lprob(ng,bow,bol,maxsuffptr,statesize,extendible);
+ };
+ inline double clprob(int* ng, int ngsize, double* bow=NULL,int* bol=NULL,char** maxsuffptr=NULL,unsigned int* statesize=NULL,bool* extendible=NULL) {
+ ngram ong(getDict());
+ ong.pushc(ng,ngsize);
+ return lprob(ong,bow,bol,maxsuffptr,statesize,extendible);
+ };
+
+ inline bool is_OOV(int code) {
+ //a word is considered OOV if its mapped value is OOV
+ return lmtable::is_OOV(getMap(code));
+ };
+
+ inline dictionary* getDict() const {
+ return dict;
+ }
+ inline virtual void dictionary_incflag(const bool flag) {
+ dict->incflag(flag);
+ };
+};
+
+}//namespace irstlm
+
+#endif
+
diff --git a/src/lmmacro.cpp b/src/lmmacro.cpp
new file mode 100644
index 0000000..2d8f482
--- /dev/null
+++ b/src/lmmacro.cpp
@@ -0,0 +1,903 @@
+// $Id: lmmacro.cpp 3631 2010-10-07 12:04:12Z bertoldi $
+
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+#include <stdio.h>
+#include <stdlib.h>
+#include <fcntl.h>
+#include <iostream>
+#include <fstream>
+#include <stdexcept>
+#include "math.h"
+#include "mempool.h"
+#include "htable.h"
+#include "ngramcache.h"
+#include "dictionary.h"
+#include "n_gram.h"
+#include "lmtable.h"
+#include "lmmacro.h"
+#include "util.h"
+
+using namespace std;
+
+// local utilities: start
+
+// Report a fatal condition: echo the message on stderr, then abort the
+// current operation by raising it as a runtime_error.
+inline void error(const char* message)
+{
+ std::cerr << message << "\n";
+ throw std::runtime_error(message);
+}
+
+// local utilities: end
+
+
+namespace irstlm {
+
+// Construct an lmmacro on top of an lmtable; allocates the dictionary of
+// micro tags and leaves it open (incflag(1)) so new micro tags can be
+// added while loading the map and processing input.
+lmmacro::lmmacro(float nlf, float dlfi):lmtable(nlf,dlfi)
+{
+ dict = new dictionary((char *)NULL,1000000); // dict of micro tags
+ getDict()->incflag(1);
+};
+
+// Destructor. unloadmap() releases the map arrays AND deletes the micro-tag
+// dictionary, but it is only called when a map was loaded; dict, however, is
+// always allocated in the constructor. The original code leaked dict whenever
+// mapFlag was false, so delete it explicitly in that case.
+// NOTE(review): mapFlag is only assigned inside load(); destroying an lmmacro
+// that was never load()ed reads an uninitialized flag — pre-existing issue,
+// confirm load() is always called before destruction.
+lmmacro::~lmmacro()
+{
+ if (mapFlag) unloadmap();
+ else delete dict;
+}
+
+
+// Load an lmmacro from a configuration file of the form:
+//   LMMACRO lmsize field [true|false]
+//   filename_of_LM
+//   filename_of_map (optional)
+// "field" selects which '#'-separated field of each input token is used
+// (-1 = whole token must be mapped, -2 = whole token used as-is) and the
+// boolean enables collapsing of adjacent micro tags mapping to one macro tag.
+void lmmacro::load(const std::string &filename,int memmap)
+{
+ VERBOSE(2,"lmmacro::load(const std::string &filename,int memmap)" << std::endl);
+
+ //get info from the configuration file
+ fstream inp(filename.c_str(),ios::in|ios::binary);
+
+ char line[MAX_LINE];
+ const char* words[MAX_TOKEN_N_MAP];
+ int tokenN;
+ inp.getline(line,MAX_LINE,'\n');
+ tokenN = parseWords(line,words,MAX_TOKEN_N_MAP);
+
+ // header must be: LMMACRO lmsize field true|false
+ if (tokenN != 4 || ((strcmp(words[0],"LMMACRO") != 0) && (strcmp(words[0],"lmmacro")!=0)))
+ error((char*)"ERROR: wrong header format of configuration file\ncorrect format: LMMACRO lmsize field [true|false]\nfilename_of_LM\nfilename_of_map (optional)");
+ maxlev = atoi(words[1]);
+ selectedField = atoi(words[2]);
+
+ if ((strcmp(words[3],"TRUE") == 0) || (strcmp(words[3],"true") == 0))
+ collapseFlag = true;
+ else if ((strcmp(words[3],"FALSE") == 0) || (strcmp(words[3],"false") == 0))
+ collapseFlag = false;
+ else
+ error((char*)"ERROR: wrong header format of configuration file\ncorrect format: LMMACRO lmsize field [true|false]\nfilename_of_LM\nfilename_of_map (optional)");
+
+#ifdef DLEXICALLM
+ selectedFieldForLexicon = atoi(words[3]);
+ collapseFlag = atoi(words[4]);
+#endif
+
+ if (selectedField == -1)
+ cerr << "no selected field: the whole string is used" << std::endl;
+ else
+ cerr << "selected field n. " << selectedField << std::endl;
+ if (collapseFlag)
+ cerr << "collapse is enabled" << std::endl;
+ else
+ cerr << "collapse is disabled" << std::endl;
+
+
+ // second line: filename of the underlying (macro-tag) LM
+ std::string lmfilename;
+ if (inp.getline(line,MAX_LINE,'\n')) {
+ tokenN = parseWords(line,words,MAX_TOKEN_N_MAP);
+ lmfilename = words[0];
+ } else
+ error((char*)"ERROR: wrong format of configuration file\ncorrect format: LMMACRO lmsize field [true|false]\nfilename_of_LM\nfilename_of_map (optional)");
+
+ // optional third line: filename of the micro->macro map
+ std::string mapfilename = "";
+ if (inp.getline(line,MAX_LINE,'\n')) {
+ tokenN = parseWords(line,words,MAX_TOKEN_N_MAP);
+ mapfilename = words[0];
+ mapFlag = true;
+ } else {
+ mapFlag = false;
+ }
+
+ inp.close();
+
+
+ std::cerr << "lmfilename:" << lmfilename << std::endl;
+ if (mapfilename != "") {
+ std::cerr << "mapfilename:" << mapfilename << std::endl;
+ } else {
+ std::cerr << "no mapfilename" << std::endl;
+ mapFlag = false;
+ }
+
+ //allow the dictionary to add new words
+ getDict()->incflag(1);
+
+
+ if ((!mapFlag) && (collapseFlag)) {
+ error((char*)"ERROR: you must specify a map if you want to collapse a specific field!");
+ }
+#ifdef DLEXICALLM
+
+ // NOTE(review): at this point words[] holds the tokens of the LAST line
+ // parsed above, so words[2] is unlikely to be the intended value — confirm
+ // against the DLEXICALLM configuration format.
+ std::string lexicalclassesfilename = words[2];
+ // BUGFIX: the original condition was inverted ("!= NULL && != null"): it
+ // cleared every genuine filename and kept the "NULL" sentinel. Clear the
+ // name only when it IS the sentinel, so a real class file gets loaded.
+ if (lexicalclassesfilename == "NULL" || lexicalclassesfilename == "null") lexicalclassesfilename = "";
+
+ if (lexicalclassesfilename != "") std::cerr << "lexicalclassesfilename:" << lexicalclassesfilename << std::endl;
+ else std::cerr << "no lexicalclassesfilename" << std::endl;
+
+ // Load the classes of lexicalization tokens:
+ if (lexicalclassesfilename != "") loadLexicalClasses(lexicalclassesfilename.c_str());
+#endif
+
+ // Load the (possibly binary) LM
+ lmtable::load(lmfilename,memmap);
+
+ getDict()->incflag(1);
+
+ if (mapFlag)
+ loadmap(mapfilename);
+ getDict()->genoovcode();
+
+};
+
+// Release the micro-tag dictionary and the map arrays allocated by loadmap().
+// NOTE(review): pointers are not reset to NULL after freeing, so a second
+// call would double-free — confirm unloadmap() is only invoked once (from the
+// destructor when mapFlag is set).
+void lmmacro::unloadmap()
+{
+ delete dict;
+ free(microMacroMap);
+ if (collapseFlag) {
+ free(collapsableMap);
+ free(collapsatorMap);
+ }
+#ifdef DLEXICALLM
+ free(lexicaltoken2classMap);
+#endif
+}
+
+// Load the micro->macro tag map from mapfilename (one "microW macroW" pair
+// per line). Fills dict (micro-tag dictionary) and microMacroMap (micro code
+// -> macro code in the lmtable dictionary); when collapseFlag is on, also
+// fills collapsableMap/collapsatorMap from the '(' / ')' / '+' suffix of
+// each micro tag. All arrays grow in BUFSIZ-sized chunks.
+void lmmacro::loadmap(const std::string mapfilename)
+{
+ microMacroMapN = 0;
+ microMacroMap = NULL;
+ collapsableMap = NULL;
+ collapsatorMap = NULL;
+
+#ifdef DLEXICALLM
+ lexicaltoken2classMap = NULL;
+ lexicaltoken2classMapN = 0;
+#endif
+
+ microMacroMap = (int *)calloc(BUFSIZ, sizeof(int));
+ if (collapseFlag) {
+ collapsableMap = (bool *)calloc(BUFSIZ, sizeof(bool));
+ collapsatorMap = (bool *)calloc(BUFSIZ, sizeof(bool));
+ }
+
+
+ // entry 0 of the map: micro OOV -> macro OOV
+ getDict()->genoovcode();
+ microMacroMap[microMacroMapN] = lmtable::getDict()->oovcode();
+ MY_ASSERT(microMacroMapN == getDict()->oovcode());
+ microMacroMapN++;
+
+
+ // make sure the macro dictionary knows the sentence delimiters
+ if (lmtable::getDict()->getcode(BOS_)==-1) {
+ lmtable::getDict()->incflag(1);
+ lmtable::getDict()->encode(BOS_);
+ lmtable::getDict()->incflag(0);
+ }
+
+ if (lmtable::getDict()->getcode(EOS_)==-1) {
+ lmtable::getDict()->incflag(1);
+ lmtable::getDict()->encode(EOS_);
+ lmtable::getDict()->incflag(0);
+ }
+
+ char line[MAX_LINE];
+ const char* words[MAX_TOKEN_N_MAP];
+ const char *macroW;
+ const char *microW;
+ int tokenN;
+ bool bos=false,eos=false;
+
+ // Load the dictionary of micro tags (to be put in "dict" of lmmacro class):
+ inputfilestream inpMap(mapfilename.c_str());
+ std::cerr << "Reading map " << mapfilename << "..." << std::endl;
+ while (inpMap.getline(line,MAX_LINE,'\n')) {
+ tokenN = parseWords(line,words,MAX_TOKEN_N_MAP);
+ if (tokenN != 2)
+ error((char*)"ERROR: wrong format of map file\n");
+ microW = words[0];
+ macroW = words[1];
+ int microW_c=getDict()->encode(microW);
+ VERBOSE(4, "microW gets the code:" << microW_c << std::endl);
+
+ // grow the arrays by one BUFSIZ chunk whenever a chunk boundary is reached
+ if (microMacroMapN>0 && !(microMacroMapN % BUFSIZ)) {
+ microMacroMap = (int *)reallocf(microMacroMap, sizeof(int)*(BUFSIZ*(1+microMacroMapN/BUFSIZ)));
+ if (collapseFlag) {
+ //create supporting info for collapse
+
+ collapsableMap = (bool *)reallocf(collapsableMap, sizeof(bool)*(BUFSIZ*(1+microMacroMapN/BUFSIZ)));
+ collapsatorMap = (bool *)reallocf(collapsatorMap, sizeof(bool)*(BUFSIZ*(1+microMacroMapN/BUFSIZ)));
+ }
+ }
+ microMacroMap[microMacroMapN] = lmtable::getDict()->getcode(macroW);
+
+ if (collapseFlag) {
+
+ // last character of the micro tag encodes its chunk role:
+ // '(' opens a chunk, ')' closes one, '+' both continues and closes
+ int len = strlen(microW)-1;
+ if (microW[len] == '(') {
+ collapsableMap[microMacroMapN] = false;
+ collapsatorMap[microMacroMapN] = true;
+ } else if (microW[len] == ')') {
+ collapsableMap[microMacroMapN] = true;
+ collapsatorMap[microMacroMapN] = false;
+ } else if (microW[len] == '+') {
+ collapsableMap[microMacroMapN] = true;
+ collapsatorMap[microMacroMapN] = true;
+ } else {
+ collapsableMap[microMacroMapN] = false;
+ collapsatorMap[microMacroMapN] = false;
+ }
+ }
+
+ // remember whether the map itself already defines <s> and </s>
+ if (!bos && !strcmp(microW,BOS_)) bos=true;
+ if (!eos && !strcmp(microW,EOS_)) eos=true;
+
+ VERBOSE(2,"\nmicroW = " << microW << "\n"
+ << "macroW = " << macroW << "\n"
+ << "microMacroMapN = " << microMacroMapN << "\n"
+ << "code of micro = " << getDict()->getcode(microW) << "\n"
+ << "code of macro = " << lmtable::getDict()->getcode(macroW) << "\n");
+
+ microMacroMapN++;
+ }
+
+ if ((microMacroMapN == 0) && (selectedField == -1))
+ error((char*)"ERROR: with no field selection, a map for the whole string is mandatory\n");
+
+ if (microMacroMapN>0) {
+ // Add <s>-><s> to map if missing
+ if (!bos) {
+ getDict()->encode(BOS_);
+ if (microMacroMapN && !(microMacroMapN%BUFSIZ))
+ microMacroMap = (int *)reallocf(microMacroMap, sizeof(int)*(microMacroMapN+BUFSIZ));
+ microMacroMap[microMacroMapN++] = lmtable::getDict()->getcode(BOS_);
+ }
+
+ // Add </s>-></s> to map if missing
+ if (!eos) {
+ getDict()->encode(EOS_);
+ if (microMacroMapN && !(microMacroMapN%BUFSIZ))
+ microMacroMap = (int *)reallocf(microMacroMap, sizeof(int)*(microMacroMapN+BUFSIZ));
+ microMacroMap[microMacroMapN++] = lmtable::getDict()->getcode(EOS_);
+ }
+ }
+ // getDict()->incflag(0);
+
+ VERBOSE(2,"oovcode(micro)=" << getDict()->oovcode() << "\n"
+ << "oovcode(macro)=" << lmtable::getDict()->oovcode() << "\n"
+ << "microMacroMapN = " << microMacroMapN << "\n"
+ << "macrodictsize = " << getDict()->size() << "\n"
+ << "microdictsize = " << lmtable::getDict()->size() << "\n");
+
+ IFVERBOSE(2) {
+ for (int i=0; i<microMacroMapN; i++) {
+ VERBOSE(2,"micro[" << getDict()->decode(i) << "] {"<< i << "} -> " << lmtable::getDict()->decode(microMacroMap[i]) << " {" << microMacroMap[i]<< "}" << "\n");
+ }
+ }
+ std::cerr << "...done\n";
+}
+
+
+// Log-probability of a micro-tag ngram: translate it into the macro-tag
+// space and query the underlying lmtable.
+double lmmacro::lprob(ngram micro_ng)
+{
+ VERBOSE(2,"lmmacro::lprob, parameter = <" << micro_ng << ">\n");
+
+ ngram macro_ng(lmtable::getDict());
+
+ if (micro_ng.dict == macro_ng.dict)
+ macro_ng.trans(micro_ng); // micro to macro mapping already done
+ else
+ // BUGFIX: restored "map(&micro_ng, &macro_ng)" — the archived patch was
+ // corrupted by HTML-entity mangling ("&micro;" -> 'µ', "&macr;" -> '¯'),
+ // yielding the non-compiling "map(µ_ng, ¯o_ng)".
+ map(&micro_ng, &macro_ng); // mapping required
+
+ VERBOSE(3,"lmmacro::lprob: micro_ng = " << micro_ng << "\n"
+ << "lmmacro::lprob: macro_ng = " << macro_ng << "\n");
+
+ // ask LM with macro
+ double prob;
+ prob = lmtable::lprob(macro_ng);
+ VERBOSE(3,"prob = " << prob << "\n");
+
+ return prob;
+};
+
+// Convenience overload: wrap the raw code array into a micro-tag ngram and
+// delegate to the ngram-based clprob.
+double lmmacro::clprob(int* codes, int sz, double* bow, int* bol, char** state,unsigned int* statesize,bool* extendible)
+{
+ ngram ng(getDict());
+ ng.pushc(codes,sz);
+ return clprob(ng,bow,bol,state,statesize,extendible);
+}
+
+// Cached log-probability of a micro-tag ngram: apply the full transformation
+// (field selection, optional collapsing, mapping) and query the macro LM.
+// Returns 0.0 when the newest token merely continues an open chunk, since
+// that chunk was already scored when it opened.
+double lmmacro::clprob(ngram micro_ng, double* bow, int* bol, char** state,unsigned int* statesize,bool* extendible)
+{
+
+ VERBOSE(3," lmmacro::clprob(ngram), parameter = <" << micro_ng << ">\n");
+
+ ngram transformed_ng(lmtable::getDict());
+ bool collapsed = transform(micro_ng, transformed_ng);
+ VERBOSE(3,"lmmacro::clprob(ngram), transformed_ng = <" << transformed_ng << ">\n");
+
+ double logpr;
+
+ if (collapsed) {
+ // the last token of the ngram continues an already open "chunk";
+ // the chunk-level probability is not recomputed because it was already
+ // computed when the "chunk" opened
+ VERBOSE(3," SKIPPED call to lmtable::clprob because of collapse; logpr: 0.0\n");
+ logpr = 0.0;
+ } else {
+ VERBOSE(3," QUERY MACRO LM on (after transformation and size reduction) " << transformed_ng << "\n");
+ logpr = lmtable::clprob(transformed_ng, bow, bol, state, statesize, extendible);
+ }
+ VERBOSE(3," GET logpr: " << logpr << "\n");
+
+ return logpr;
+}
+
+// Full micro->macro pipeline: (1) optional field selection, (2) optional
+// collapsing of same-chunk tokens, (3) optional map lookup; finally trims
+// the result to the macro LM's maximum order. Returns true iff the newest
+// token was collapsed into an already open chunk.
+bool lmmacro::transform(ngram &in, ngram &out)
+{
+ VERBOSE(3,"lmmacro::transform(ngram &in, ngram &out), in = <" << in << ">\n");
+
+ //step 1: selection of the correct field
+ ngram field_ng(getDict());
+ if (selectedField >= 0)
+ field_selection(in, field_ng);
+ else
+ field_ng = in;
+
+ //step 2: collapsing
+ ngram collapsed_ng(getDict());
+ bool collapsed = false;
+ if (collapseFlag)
+ collapsed = collapse(field_ng, collapsed_ng);
+ else
+ collapsed_ng = field_ng;
+
+ //step 3: mapping using the loaded map
+ if (mapFlag)
+ mapping(collapsed_ng, out);
+ else
+ out.trans(collapsed_ng);
+
+ // never pass the macro LM an ngram longer than its order
+ if (out.size>lmtable::maxlevel()) out.size=lmtable::maxlevel();
+
+ VERBOSE(3,"lmmacro::transform(ngram &in, ngram &out), out = <" << out << ">\n");
+ return collapsed;
+}
+
+
+
+// Extract field number selectedField (fields are '#'-separated) from every
+// token of `in`, pushing the results onto `out`. Sentence delimiters and
+// "_unk_" pass through unchanged; tokens lacking the field become "_unk_".
+// NOTE(review): uses strtok, which relies on global state — not reentrant
+// and not thread-safe.
+void lmmacro::field_selection(ngram &in, ngram &out)
+{
+ VERBOSE(3,"In lmmacro::field_selection(ngram &in, ngram &out) in = " << in << "\n");
+
+ int microsize = in.size;
+
+ // iterate from the oldest token (position size) to the newest (position 1)
+ for (int i=microsize; i>0; i--) {
+
+ char curr_token[BUFSIZ];
+ strcpy(curr_token, getDict()->decode(*in.wordp(i)));
+ char *field;
+ if (strcmp(curr_token,"<s>") &&
+ strcmp(curr_token,"</s>") &&
+ strcmp(curr_token,"_unk_")) {
+ field = strtok(curr_token, "#");
+ int j=0;
+ while (j<selectedField && field != NULL) {
+ field = strtok(0, "#");
+ j++;
+ }
+ } else {
+ field = curr_token;
+ }
+
+
+ if (field) {
+ out.pushw(field);
+ } else {
+
+ out.pushw((char*)"_unk_");
+
+ // cerr << *in << "\n";
+ // error((char*)"ERROR: Malformed input: selected field does not exist in token\n");
+
+ /**
+ We can be here in 2 cases:
+
+ a. effectively when the token is malformed, that is the selected
+ field does not exist
+
+ b. in case of verbatim translation, that is the source word is
+ not known to the phrase table and moses transfers it as it is
+ to the target side: in this case, no assumption can be made on its
+ format, which means that the selected field can not exist
+
+ The possibility of case (b) makes incorrect the error exit from
+ the code at this point: correct, on the contrary, push the _unk_ string
+ **/
+ }
+ }
+ VERBOSE(3,"In lmmacro::field_selection(ngram &in, ngram &out) out = " << out << "\n");
+ return;
+}
+
+// Collapse runs of micro tags that belong to the same chunk (same mapped
+// macro tag, compatible collapsable/collapsator roles) into single tokens
+// of `out`. Returns true iff the newest token continues an open chunk — in
+// that case `out` is left unfilled and the caller skips the LM query.
+// NOTE(review): microMacroMap/collapsableMap are indexed by raw micro codes
+// with no bound check against microMacroMapN here — confirm input codes are
+// always within the map (transform() runs field_selection first).
+bool lmmacro::collapse(ngram &in, ngram &out)
+{
+ VERBOSE(3,"In lmmacro::collapse(ngram &in, ngram &out) in = " << in << "\n")
+
+ // fill the ngram out with the collapsed tokens
+ //return true if collapse happens for the most recent token
+ //return false if collapse does not happen for the most recent token
+ int microsize = in.size;
+ out.size = 0;
+
+ // a single token can never be collapsed
+ if (microsize == 1) {
+ out.pushc(*in.wordp(1));
+ return false;
+ }
+
+ int curr_code = *in.wordp(1);
+ int prev_code = *in.wordp(2);
+
+ // newest token continues the chunk opened by the previous one?
+ if (microMacroMap[curr_code] == microMacroMap[prev_code]) {
+ if (collapsableMap[curr_code] && collapsatorMap[prev_code]) {
+ return true;
+ }
+ }
+
+ //collapse does not happen for the most recent token
+ // collapse all previous tokens, but the last
+
+ prev_code = *in.wordp(microsize);
+ out.pushc(prev_code);
+
+ for (int i=microsize-1; i>1; i--) {
+
+ curr_code = *in.wordp(i);
+
+ if (microMacroMap[curr_code] != microMacroMap[prev_code]) {
+ out.pushc(curr_code);
+ } else {
+ if (!(collapsableMap[curr_code] && collapsatorMap[prev_code])) {
+ out.pushc(prev_code);
+ }
+ }
+ prev_code = curr_code;
+ }
+ // and insert the most recent token
+ out.pushc(*in.wordp(1));
+ VERBOSE(3,"In lmmacro::collapse(ngram &in, ngram &out) out = " << out << "\n");
+ return false;
+}
+
+// Translate each micro-tag code of `in` into its macro-tag code via
+// microMacroMap, pushing the results onto `out`; codes outside the map
+// become the macro-dictionary OOV code.
+void lmmacro::mapping(ngram &in, ngram &out)
+{
+ VERBOSE(3,"In lmmacro::mapping(ngram &in, ngram &out) in = " << in << "\n");
+
+ const int n = in.size;
+
+ // walk from the oldest token (position n) down to the newest (position 1)
+ for (int pos = n; pos > 0; pos--) {
+ const int micro_code = *in.wordp(pos);
+ const int macro_code = (micro_code < microMacroMapN)
+ ? microMacroMap[micro_code]
+ : lmtable::getDict()->oovcode();
+ out.pushc(macro_code);
+ }
+ VERBOSE(3,"In lmmacro::mapping(ngram &in, ngram &out) out = " << out << "\n");
+ return;
+}
+
+
+//maxsuffptr returns the largest suffix of an n-gram that is contained
+//in the LM table. This can be used as a compact representation of the
+//(n-1)-gram state of a n-gram LM. if the input k-gram has k>=n then it
+//is trimmed to its n-1 suffix.
+
+const char *lmmacro::maxsuffptr(ngram micro_ng, unsigned int* size)
+{
+ ngram macro_ng(lmtable::getDict());
+
+ if (micro_ng.dict == macro_ng.dict)
+ macro_ng.trans(micro_ng); // micro to macro mapping already done
+ else
+ // BUGFIX: restored "map(&micro_ng, &macro_ng)" — the archived patch was
+ // corrupted by HTML-entity mangling ("&micro;" -> 'µ', "&macr;" -> '¯').
+ map(&micro_ng, &macro_ng); // mapping required
+
+ VERBOSE(2,"lmmacro::lprob: micro_ng = " << micro_ng << "\n"
+ << "lmmacro::lprob: macro_ng = " << macro_ng << "\n");
+
+ return lmtable::maxsuffptr(macro_ng,size);
+}
+
+// Cached variant of maxsuffptr: translate the micro ngram into the macro
+// space and delegate to lmtable::cmaxsuffptr.
+const char *lmmacro::cmaxsuffptr(ngram micro_ng, unsigned int* size)
+{
+ //cerr << "lmmacro::CMAXsuffptr\n";
+ //cerr << "micro_ng: " << micro_ng
+ // << " -> micro_ng.size: " << micro_ng.size << "\n";
+
+ //the LM working on the selected field = 0
+ //contributes to the LM state
+ // if (selectedField>0) return NULL;
+
+ ngram macro_ng(lmtable::getDict());
+
+ if (micro_ng.dict == macro_ng.dict)
+ macro_ng.trans(micro_ng); // micro to macro mapping already done
+ else
+ // BUGFIX: restored "map(&micro_ng, &macro_ng)" — the archived patch was
+ // corrupted by HTML-entity mangling ("&micro;" -> 'µ', "&macr;" -> '¯').
+ map(&micro_ng, &macro_ng); // mapping required
+
+ VERBOSE(2,"lmmacro::lprob: micro_ng = " << micro_ng << "\n"
+ << "lmmacro::lprob: macro_ng = " << macro_ng << "\n")
+
+ return lmtable::cmaxsuffptr(macro_ng,size);
+
+}
+
+
+// Dispatch the micro->macro translation according to selectedField:
+//   -2  : tokens are already LM words (one-to-one copy through the map)
+//   -1  : whole tokens are mapped micro->macro
+//  0..9 : extract field selectedField ('#'-separated) then map
+//  >=10 : (DLEXICALLM only) tens digit = tag field, units digit = lemma
+//         field; the mapped tag is lexicalized with the lemma
+void lmmacro::map(ngram *in, ngram *out)
+{
+
+ VERBOSE(2,"In lmmacro::map, in = " << *in << endl
+ << " (selectedField = " << selectedField << " )\n");
+
+ if (selectedField==-2) // the whole token is compatible with the LM words
+ One2OneMapping(in, out);
+
+ else if (selectedField==-1) // the whole token has to be mapped before querying the LM
+ Micro2MacroMapping(in, out);
+
+ else if (selectedField<10) { // select the field "selectedField" from tokens (separator is assumed to be "#")
+ ngram field_ng(((lmmacro *)this)->getDict());
+ int microsize = in->size;
+
+ // iterate from the oldest token to the newest
+ for (int i=microsize; i>0; i--) {
+
+ char curr_token[BUFSIZ];
+ strcpy(curr_token, ((lmmacro *)this)->getDict()->decode(*(in->wordp(i))));
+ char *field;
+ if (strcmp(curr_token,"<s>") &&
+ strcmp(curr_token,"</s>") &&
+ strcmp(curr_token,"_unk_")) {
+ field = strtok(curr_token, "#");
+ int j=0;
+ while (j<selectedField && field != NULL) {
+ field = strtok(0, "#");
+ j++;
+ }
+ } else {
+ field = curr_token;
+ }
+
+ if (field)
+ field_ng.pushw(field);
+ else {
+
+ field_ng.pushw((char*)"_unk_");
+
+ // cerr << *in << "\n";
+ // error((char*)"ERROR: Malformed input: selected field does not exist in token\n");
+
+ /**
+ We can be here in 2 cases:
+
+ a. effectively when the token is malformed, that is the selected
+ field does not exist
+
+ b. in case of verbatim translation, that is the source word is
+ not known to the phrase table and moses transfers it as it is
+ to the target side: in this case, no assumption can be made on its
+ format, which means that the selected field can not exist
+
+ The possibility of case (b) makes incorrect the error exit from
+ the code at this point: correct, on the contrary, push the _unk_ string
+ **/
+ }
+ }
+ // map the selected fields, unless no map was loaded at all
+ if (microMacroMapN>0)
+ Micro2MacroMapping(&field_ng, out);
+ else
+ out->trans(field_ng);
+ } else {
+
+#ifdef DLEXICALLM
+ // selectedField>=10: tens=idx of micro tag (possibly to be mapped to
+ // macro tag), unidx=idx of lemma to be concatenated by "_" to the
+ // (mapped) tag
+
+ int tagIdx = selectedField/10;
+ int lemmaIdx = selectedField%10;
+
+ // micro (or mapped to macro) sequence construction:
+ ngram tag_ng(getDict());
+ char *lemmas[BUFSIZ];
+
+ int microsize = in->size;
+ for (int i=microsize; i>0; i--) {
+ char curr_token[BUFSIZ];
+ strcpy(curr_token, getDict()->decode(*(in->wordp(i))));
+ char *tag = NULL, *lemma = NULL;
+
+ if (strcmp(curr_token,"<s>") &&
+ strcmp(curr_token,"</s>") &&
+ strcmp(curr_token,"_unk_")) {
+
+ // walk the '#'-separated fields in index order, whichever of
+ // tag/lemma comes first
+ if (tagIdx<lemmaIdx) {
+ tag = strtok(curr_token, "#");
+ for (int j=0; j<tagIdx; j++)
+ tag = strtok(0, "#");
+ for (int j=tagIdx; j<lemmaIdx; j++)
+ lemma = strtok(0, "#");
+ } else {
+ lemma = strtok(curr_token, "#");
+ for (int j=0; j<lemmaIdx; j++)
+ lemma = strtok(0, "#");
+ for (int j=lemmaIdx; j<tagIdx; j++)
+ tag = strtok(0, "#");
+ }
+
+ VERBOSE(3,"(tag,lemma) = " << tag << " " << lemma << "\n");
+ } else {
+ tag = curr_token;
+ lemma = curr_token;
+ VERBOSE(3,"(tag=lemma) = " << tag << " " << lemma << "\n");
+ }
+ if (tag) {
+ tag_ng.pushw(tag);
+ lemmas[i] = strdup(lemma);
+ } else {
+ tag_ng.pushw((char*)"_unk_");
+ lemmas[i] = strdup("_unk_");
+ }
+ }
+
+ if (microMacroMapN>0)
+ Micro2MacroMapping(&tag_ng, out, lemmas);
+ else
+ out->trans(tag_ng); // here the tags should be replaced by tag_lemma, without mapping!
+
+#endif
+
+ }
+
+ VERBOSE(2,"In lmmacro::map, FINAL out = " << *out << endl);
+}
+
+// Map each token of *in onto exactly one token of *out through the
+// micro->macro map, preserving sequence length; codes outside the map
+// become the macro-dictionary OOV token.
+void lmmacro::One2OneMapping(ngram *in, ngram *out)
+{
+ const int len = in->size;
+
+ for (int pos = len; pos > 0; pos--) {
+ const int code = *(in->wordp(pos));
+ int mapped;
+ if (code < microMacroMapN)
+ mapped = microMacroMap[code];
+ else
+ mapped = lmtable::getDict()->oovcode();
+ out->pushw(lmtable::getDict()->decode(mapped));
+ }
+ return;
+}
+
+
+// Map a micro-tag sequence (in) onto the corresponding — possibly shorter —
+// macro-tag sequence (out): a macro tag is emitted only when the current
+// micro tag does not continue the chunk opened by the previous one, as
+// signalled by the '(' / ')' / '+' suffix (or '(' prefix) conventions of
+// the micro-tag names.
+void lmmacro::Micro2MacroMapping(ngram *in, ngram *out)
+{
+
+ int microsize = in->size;
+
+ VERBOSE(2,"In Micro2MacroMapping, in = " << *in << "\n");
+
+ // map microtag sequence (in) into the corresponding sequence of macrotags (possibly shorter) (out)
+
+ for (int i=microsize; i>0; i--) {
+
+ int curr_code = *(in->wordp(i));
+ const char *curr_macrotag = lmtable::getDict()->decode((curr_code<microMacroMapN)?microMacroMap[curr_code]:lmtable::getDict()->oovcode());
+
+ // the oldest token always opens the sequence
+ if (i==microsize) {
+ out->pushw(curr_macrotag);
+
+ } else {
+ int prev_code = *(in->wordp(i+1));
+
+ const char *prev_microtag = getDict()->decode(prev_code);
+ const char *curr_microtag = getDict()->decode(curr_code);
+ const char *prev_macrotag = lmtable::getDict()->decode((prev_code<microMacroMapN)?microMacroMap[prev_code]:lmtable::getDict()->oovcode());
+
+
+ int prev_len = strlen(prev_microtag)-1;
+ int curr_len = strlen(curr_microtag)-1;
+
+ // emit the macro tag unless the current micro tag closes/continues a
+ // chunk opened or continued by the previous one (same macro tag AND a
+ // legal '(' -> ')' / '(' -> '+' / '+' -> '+' / '+' -> ')' transition)
+ if (strcmp(curr_macrotag,prev_macrotag) != 0 ||
+ !(
+ (( prev_microtag[prev_len]== '(' || ( prev_microtag[0]== '(' && prev_microtag[prev_len]!= ')' )) && ( curr_microtag[curr_len]==')' && curr_microtag[0]!='(')) ||
+ (( prev_microtag[prev_len]== '(' || ( prev_microtag[0]== '(' && prev_microtag[prev_len]!= ')' )) && curr_microtag[curr_len]=='+' ) ||
+ (prev_microtag[prev_len]== '+' && curr_microtag[curr_len]=='+' ) ||
+ (prev_microtag[prev_len]== '+' && ( curr_microtag[curr_len]==')' && curr_microtag[0]!='(' ))))
+ out->pushw(curr_macrotag);
+ }
+ }
+ return;
+}
+
+
+
+// DISMITTED ON FEB 2011 BECAUSE TOO MUCH PROBLEMATIC FROM A THEORETICAL POINT OF VIEW
+
+#ifdef DLEXICALLM
+
+// Lexicalized variant of Micro2MacroMapping (DLEXICALLM builds only): each
+// emitted macro tag is suffixed with "_lemma" (or "_classN" when a lexical
+// class map is loaded), except while still inside an open chunk. Frees the
+// strdup'ed lemmas[] entries as it consumes them.
+// NOTE(review): tag_lemma is a fixed BUFSIZ buffer filled via sprintf —
+// overly long macrotag/lemma strings would overflow it.
+void lmmacro::Micro2MacroMapping(ngram *in, ngram *out, char **lemmas)
+{
+ VERBOSE(2,"In Micro2MacroMapping, in = " << *in << "\n")
+
+ int microsize = in->size;
+
+ IFVERBOSE(3) {
+ VERBOSE(3,"In Micro2MacroMapping, lemmas:\n");
+ if (lexicaltoken2classMap)
+ for (int i=microsize; i>0; i--)
+ VERBOSE(3,"lemmas[" << i << "]=" << lemmas[i] << " -> class -> " << lexicaltoken2classMap[lmtable::getDict()->encode(lemmas[i])] << endl);
+ else
+ for (int i=microsize; i>0; i--)
+ VERBOSE(3,"lemmas[" << i << "]=" << lemmas[i] << endl);
+ }
+
+ // map microtag sequence (in) into the corresponding sequence of macrotags (possibly shorter) (out)
+
+ char tag_lemma[BUFSIZ];
+
+ for (int i=microsize; i>0; i--) {
+
+ int curr_code = *(in->wordp(i));
+
+ const char *curr_microtag = getDict()->decode(curr_code);
+ const char *curr_lemma = lemmas[i];
+ const char *curr_macrotag = lmtable::getDict()->decode((curr_code<microMacroMapN)?microMacroMap[curr_code]:lmtable::getDict()->oovcode());
+ int curr_len = strlen(curr_microtag)-1;
+
+ if (i==microsize) {
+ if (( curr_microtag[curr_len]=='(' ) || ( curr_microtag[0]=='(' && curr_microtag[curr_len]!=')' ) || ( curr_microtag[curr_len]=='+' ))
+ sprintf(tag_lemma, "%s", curr_macrotag); // do not lexicalize the macrotag while still inside the chunk
+ else if (lexicaltoken2classMap)
+ sprintf(tag_lemma, "%s_class%d", curr_macrotag, lexicaltoken2classMap[lmtable::getDict()->encode(curr_lemma)]);
+ else
+ sprintf(tag_lemma, "%s_%s", curr_macrotag, lemmas[microsize]);
+
+ VERBOSE(2,"In Micro2MacroMapping, starting tag_lemma = >" << tag_lemma << "<\n");
+
+ out->pushw(tag_lemma);
+ free(lemmas[microsize]);
+
+
+ } else {
+
+ int prev_code = *(in->wordp(i+1));
+ const char *prev_microtag = getDict()->decode(prev_code);
+ const char *prev_macrotag = lmtable::getDict()->decode((prev_code<microMacroMapN)?microMacroMap[prev_code]:lmtable::getDict()->oovcode());
+
+
+ int prev_len = strlen(prev_microtag)-1;
+
+ if (( curr_microtag[curr_len]=='(' ) || ( curr_microtag[0]=='(' && curr_microtag[curr_len]!=')' ) || ( curr_microtag[curr_len]=='+' ))
+ sprintf(tag_lemma, "%s", curr_macrotag); // do not lexicalize the macrotag while still inside the chunk
+ else if (lexicaltoken2classMap)
+ sprintf(tag_lemma, "%s_class%d", curr_macrotag, lexicaltoken2classMap[lmtable::getDict()->encode(curr_lemma)]);
+ else
+ sprintf(tag_lemma, "%s_%s", curr_macrotag, curr_lemma);
+
+ VERBOSE(2,"In Micro2MacroMapping, tag_lemma = >" << tag_lemma << "<\n");
+
+ // emit a new token on chunk boundaries; otherwise replace the last
+ // emitted token (shift then push) with the updated lexicalization
+ if (strcmp(curr_macrotag,prev_macrotag) != 0 ||
+ !(
+ (( prev_microtag[prev_len]== '(' || ( prev_microtag[0]== '(' && prev_microtag[prev_len]!=')' )) && curr_microtag[curr_len]==')' && curr_microtag[0]!='(') ||
+ (( prev_microtag[prev_len]== '(' || ( prev_microtag[0]== '(' && prev_microtag[prev_len]!= ')')) && curr_microtag[curr_len]=='+' ) ||
+ (prev_microtag[prev_len]== '+' && curr_microtag[curr_len]=='+' ) ||
+ (prev_microtag[prev_len]== '+' && curr_microtag[curr_len]==')' && curr_microtag[0]!='(' ))) {
+
+ VERBOSE(2,"In Micro2MacroMapping, before pushw, out = " << *out << endl);
+ out->pushw(tag_lemma);
+ VERBOSE(2,"In Micro2MacroMapping, after pushw, out = " << *out << endl);
+ } else {
+ VERBOSE(2,"In Micro2MacroMapping, before shift, out = " << *out << endl);
+ out->shift();
+ VERBOSE(2,"In Micro2MacroMapping, after shift, out = " << *out << endl);
+ out->pushw(tag_lemma);
+ VERBOSE(2,"In Micro2MacroMapping, after push, out = " << *out << endl);
+ }
+ free(lemmas[i]);
+ }
+ }
+ return;
+}
+
+// Load the lemma->class table used for lexicalized macro tags (DLEXICALLM
+// builds only). File format: one "word classIdx" pair per line. The table is
+// indexed by the word's code in the macro (lmtable) dictionary and grows in
+// BUFSIZ-sized chunks.
+void lmmacro::loadLexicalClasses(const char *fn)
+{
+ char line[MAX_LINE];
+ const char* words[MAX_TOKEN_N_MAP];
+ int tokenN;
+
+ lexicaltoken2classMap = (int *)calloc(BUFSIZ, sizeof(int));
+ lexicaltoken2classMapN = BUFSIZ;
+
+ // temporarily allow new words into the macro dictionary
+ lmtable::getDict()->incflag(1);
+
+ inputfilestream inp(fn);
+ while (inp.getline(line,MAX_LINE,'\n')) {
+ tokenN = parseWords(line,words,MAX_TOKEN_N_MAP);
+ if (tokenN != 2)
+ error((char*)"ERROR: wrong format of lexical classes file\n");
+ else {
+ int classIdx = atoi(words[1]);
+ int wordCode = lmtable::getDict()->encode(words[0]);
+
+ // grow the table in whole BUFSIZ chunks until wordCode fits
+ if (wordCode>=lexicaltoken2classMapN) {
+ int r = (wordCode-lexicaltoken2classMapN)/BUFSIZ;
+ lexicaltoken2classMapN += (r+1)*BUFSIZ;
+ lexicaltoken2classMap = (int *)reallocf(lexicaltoken2classMap, sizeof(int)*lexicaltoken2classMapN);
+ }
+ lexicaltoken2classMap[wordCode] = classIdx;
+ }
+ }
+
+ lmtable::getDict()->incflag(0);
+
+ IFVERBOSE(3) {
+ for (int x=0; x<lmtable::getDict()->size(); x++)
+ VERBOSE(3,"class of <" << lmtable::getDict()->decode(x) << "> (code=" << x << ") = " << lexicaltoken2classMap[x] << endl);
+ }
+
+ return;
+}
+
+
+// Strip the lexicalization suffix ("_lemma", introduced after the last '_')
+// from the most recent token of *in, writing the result to *out: the last
+// token is popped (shift) and re-pushed without its "_..." tail.
+void lmmacro::cutLex(ngram *in, ngram *out)
+{
+ *out=*in;
+
+ const char *curr_macro = out->dict->decode(*(out->wordp(1)));
+ out->shift();
+ // locate the last '_'; if absent the token is copied unchanged
+ const char *p = strrchr(curr_macro, '_');
+ int lexLen;
+ if (p)
+ lexLen=strlen(p);
+ else
+ lexLen=0;
+ char curr_NoLexMacro[BUFSIZ];
+ memset(&curr_NoLexMacro,0,BUFSIZ);
+ strncpy(curr_NoLexMacro,curr_macro,strlen(curr_macro)-lexLen);
+ out->pushw(curr_NoLexMacro);
+ return;
+}
+#endif
+
+}//namespace irstlm
diff --git a/src/lmmacro.h b/src/lmmacro.h
new file mode 100644
index 0000000..bfeab6d
--- /dev/null
+++ b/src/lmmacro.h
@@ -0,0 +1,133 @@
+// $Id: lmmacro.h 3461 2010-08-27 10:17:34Z bertoldi $
+
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+
+
+#ifndef MF_LMMACRO_H
+#define MF_LMMACRO_H
+
+#ifndef WIN32
+#include <sys/types.h>
+#include <sys/mman.h>
+#endif
+
+#include "util.h"
+#include "ngramcache.h"
+#include "dictionary.h"
+#include "n_gram.h"
+#include "lmtable.h"
+
+#define MAX_TOKEN_N_MAP 5
+
+namespace irstlm {
+
+// LM over "macro" tags queried with sequences of "micro" tags: input tokens
+// are optionally reduced to one '#'-separated field, optionally collapsed
+// (adjacent micro tags of one chunk -> one token) and mapped micro->macro
+// before the underlying lmtable is consulted.
+class lmmacro: public lmtable
+{
+
+ dictionary *dict; // dictionary of micro tags (macro dict lives in lmtable)
+ int maxlev; //max level of table; NOTE(review): shadows state of the lmtable base — confirm intended
+ int selectedField; // -2 whole token as-is, -1 whole token mapped, 0..9 field index, >=10 tag/lemma pair (DLEXICALLM)
+
+ bool collapseFlag; //flag for the presence of collapse
+ bool mapFlag; //flag for the presence of map
+
+ int microMacroMapN; // number of entries in microMacroMap
+ int *microMacroMap; // micro code -> macro code
+ bool *collapsableMap; // micro code may close/continue a chunk (')' or '+')
+ bool *collapsatorMap; // micro code may open/continue a chunk ('(' or '+')
+
+#ifdef DLEXICALLM
+ int selectedFieldForLexicon;
+ int *lexicaltoken2classMap; // lemma code (macro dict) -> lexical class
+ int lexicaltoken2classMapN; // allocated size of lexicaltoken2classMap
+#endif
+
+
+ void loadmap(const std::string mapfilename);
+ void unloadmap();
+
+ // pipeline: field_selection -> collapse -> mapping; returns true when the
+ // newest token was collapsed into an open chunk
+ bool transform(ngram &in, ngram &out);
+ void field_selection(ngram &in, ngram &out);
+ bool collapse(ngram &in, ngram &out);
+ void mapping(ngram &in, ngram &out);
+
+public:
+
+ lmmacro(float nlf=0.0, float dlfi=0.0);
+ ~lmmacro();
+
+ // load from a configuration file naming the macro LM and the map
+ void load(const std::string &filename,int mmap=0);
+
+ double lprob(ngram ng);
+ double clprob(ngram ng,double* bow=NULL,int* bol=NULL,char** maxsuffptr=NULL,unsigned int* statesize=NULL,bool* extendible=NULL);
+ double clprob(int* ng, int ngsize, double* bow=NULL,int* bol=NULL,char** maxsuffptr=NULL,unsigned int* statesize=NULL,bool* extendible=NULL);
+
+ const char *maxsuffptr(ngram ong, unsigned int* size=NULL);
+ const char *cmaxsuffptr(ngram ong, unsigned int* size=NULL);
+
+ // micro->macro translation entry points (dispatch on selectedField)
+ void map(ngram *in, ngram *out);
+ void One2OneMapping(ngram *in, ngram *out);
+ void Micro2MacroMapping(ngram *in, ngram *out);
+#ifdef DLEXICALLM
+ void Micro2MacroMapping(ngram *in, ngram *out, char **lemma);
+ void loadLexicalClasses(const char *fn);
+ void cutLex(ngram *in, ngram *out);
+#endif
+
+ // a token is OOV if its selected field is unknown to the micro->macro map
+ // or if the mapped macro tag is the macro-dictionary OOV
+ inline bool is_OOV(int code) {
+ ngram word_ng(getDict());
+ ngram field_ng(getDict());
+ word_ng.pushc(code);
+ if (selectedField >= 0)
+ field_selection(word_ng, field_ng);
+ else
+ field_ng = word_ng;
+ int field_code=*field_ng.wordp(1);
+ VERBOSE(2,"inline virtual bool lmmacro::is_OOV(int code) word_ng:" << word_ng << " field_ng:" << field_ng << std::endl);
+ //the selected field(s) of a token is considered OOV
+ //either if unknown by the microMacroMap
+ //or if its mapped macroW is OOV
+ if (field_code >= microMacroMapN) return true;
+ VERBOSE(2,"inline virtual bool lmmacro::is_OOV(int code)*field_code:" << field_code << " microMacroMap[field_code]:" << microMacroMap[field_code] << " lmtable::dict->oovcode():" << lmtable::dict->oovcode() << std::endl);
+ return (microMacroMap[field_code] == lmtable::dict->oovcode());
+ };
+ // dictionary of micro tags (NOT the macro dictionary of the base lmtable)
+ inline dictionary* getDict() const {
+ return dict;
+ }
+ inline int maxlevel() const {
+ return maxlev;
+ };
+
+ inline virtual void dictionary_incflag(const bool flag) {
+ dict->incflag(flag);
+ };
+
+ // filtering is not supported for macro LMs
+ inline virtual bool filter(const string sfilter, lmContainer* sublmt, const string skeepunigrams) {
+ UNUSED(sfilter);
+ UNUSED(sublmt);
+ UNUSED(skeepunigrams);
+ return false;
+ }
+};
+
+}//namespace irstlm
+#endif
+
diff --git a/src/lmtable.cpp b/src/lmtable.cpp
new file mode 100644
index 0000000..84c31dd
--- /dev/null
+++ b/src/lmtable.cpp
@@ -0,0 +1,2948 @@
+// $Id: lmtable.cpp 3686 2010-10-15 11:55:32Z bertoldi $
+
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+
+#include <stdio.h>
+#include <cstdlib>
+#include <stdlib.h>
+#include <fcntl.h>
+#include <iostream>
+#include <fstream>
+#include <stdexcept>
+#include <string>
+#include <set>
+#include "math.h"
+#include "mempool.h"
+#include "htable.h"
+#include "ngramcache.h"
+#include "dictionary.h"
+#include "n_gram.h"
+#include "lmContainer.h"
+#include "lmtable.h"
+#include "util.h"
+
+//special value for pruned iprobs
+#define NOPROB ((float)-1.329227995784915872903807060280344576e36)
+
+using namespace std;
+
+// Report a fatal condition: trace it at verbosity level 2, then raise it
+// as a std::runtime_error for the caller to handle.
+inline void error(const char* msg)
+{
+ VERBOSE(2, msg << std::endl);
+ throw std::runtime_error(msg);
+}
+
+// Dump a prob_and_state_t record to `out` in a bracketed one-line format,
+// or "PST [NULL]" when the pointer is null.
+void print(prob_and_state_t* pst, std::ostream& out)
+{
+ if (pst == NULL) {
+ out << "PST [NULL]" << std::endl;
+ return;
+ }
+ out << "PST [";
+ out << "logpr:" << pst->logpr;
+ out << ",state:" << (void*) pst->state;
+ out << ",statesize:" << pst->statesize;
+ out << ",bow:" << pst->bow;
+ out << ",bol:" << pst->bol;
+ out << "]";
+ out << std::endl;
+}
+
+namespace irstlm {
+
+ //instantiate an empty lm table
+ // nlf/dlf are the load factors for the ngram caches and the dictionary.
+ lmtable::lmtable(float nlf, float dlf):lmContainer()
+ {
+ ngramcache_load_factor = nlf;
+ dictionary_load_factor = dlf;
+ isInverted=false;
+ configure(1,false);
+
+ dict=new dictionary((char *)NULL,1000000,dictionary_load_factor);
+ delete_dict=true;
+
+ //zero out all per-level bookkeeping arrays
+ memset(table, 0, sizeof(table));
+ memset(tableGaps, 0, sizeof(tableGaps));
+ memset(cursize, 0, sizeof(cursize));
+ memset(tbltype, 0, sizeof(tbltype));
+ memset(maxsize, 0, sizeof(maxsize));
+ memset(tb_offset, 0, sizeof(tb_offset)); //fix: was sizeof(maxsize), clearing the wrong byte count if element types differ
+ memset(info, 0, sizeof(info));
+ memset(NumCenters, 0, sizeof(NumCenters));
+
+ max_cache_lev=0;
+ for (int i=0; i<LMTMAXLEV+1; i++) lmtcache[i]=NULL;
+ for (int i=0; i<LMTMAXLEV+1; i++) prob_and_state_cache[i]=NULL;
+
+#ifdef TRACE_CACHELM
+ cacheout=new std::fstream("/tmp/tracecache",std::ios::out);
+ sentence_id=0;
+#endif
+
+ memmap=0;
+ requiredMaxlev=1000;
+
+ isPruned=false;
+ isInverted=false;
+
+ //statistics
+ // NOTE(review): this writes indices 0..LMTMAXLEV+1 (LMTMAXLEV+2 entries),
+ // one more than the cache loops above; confirm totget/totbsearch are
+ // declared with at least LMTMAXLEV+2 elements.
+ for (int i=0; i<=LMTMAXLEV+1; i++) totget[i]=totbsearch[i]=0;
+
+ logOOVpenalty=0.0; //penalty for OOV words (default 0)
+
+ // by default, it is a standard LM, i.e. queried for score
+ setOrderQuery(false);
+ };
+
+ // Release caches, per-level tables (unmapping memory-mapped levels,
+ // delete[]-ing the heap-allocated ones), quantization centroids, and
+ // - when owned - the dictionary.
+ lmtable::~lmtable()
+ {
+ delete_caches();
+
+#ifdef TRACE_CACHELM
+ cacheout->close();
+ delete cacheout;
+#endif
+
+ // levels >= memmap were mapped with an alignment gap stored in tableGaps
+ for (int l=1; l<=maxlev; l++) {
+ if (table[l]) {
+ if (memmap > 0 && l >= memmap)
+ Munmap(table[l]-tableGaps[l],cursize[l]*nodesize(tbltype[l])+tableGaps[l],0);
+ else
+ delete [] table[l];
+ }
+ if (isQtable) {
+ if (Pcenters[l]) delete [] Pcenters[l];
+ if (l<maxlev)
+ if (Bcenters[l]) delete [] Bcenters[l];
+ }
+ }
+
+ if (delete_dict) delete dict;
+ };
+
+ // Create one (prob,state) cache per level 1..max_cache_lev.
+ void lmtable::init_prob_and_state_cache()
+ {
+#ifdef PS_CACHE_ENABLE
+ for (int i=1; i<=max_cache_lev; i++)
+ {
+ MY_ASSERT(prob_and_state_cache[i]==NULL);
+ prob_and_state_cache[i]=new NGRAMCACHE_t(i,sizeof(prob_and_state_t),400000,ngramcache_load_factor); // initial number of entries is 400000
+ VERBOSE(2,"creating cache for storing prob, state and statesize of size " << i << std::endl);
+ }
+#endif
+ }
+
+ // Create one table-pointer cache per level 2..max_cache_lev
+ // (level 1 needs none: unigrams are addressed directly).
+ // void lmtable::init_lmtcaches(int uptolev)
+ void lmtable::init_lmtcaches()
+ {
+#ifdef LMT_CACHE_ENABLE
+ for (int i=2; i<=max_cache_lev; i++)
+ {
+ MY_ASSERT(lmtcache[i]==NULL);
+ lmtcache[i]=new NGRAMCACHE_t(i,sizeof(char*),200000,ngramcache_load_factor); // initial number of entries is 200000
+ }
+#endif
+ }
+
+ // Enable caching up to level `uptolev` and build both cache families.
+ void lmtable::init_caches(int uptolev)
+ {
+ max_cache_lev=uptolev;
+#ifdef PS_CACHE_ENABLE
+ init_prob_and_state_cache();
+#endif
+#ifdef LMT_CACHE_ENABLE
+ init_lmtcaches();
+#endif
+ }
+
+ // Destroy every (prob,state) cache and null its slot so the destructor
+ // and re-initialization stay safe.
+ void lmtable::delete_prob_and_state_cache()
+ {
+#ifdef PS_CACHE_ENABLE
+ for (int i=1; i<=max_cache_lev; i++)
+ {
+ if (prob_and_state_cache[i])
+ {
+ delete prob_and_state_cache[i];
+ }
+ prob_and_state_cache[i]=NULL;
+ }
+#endif
+ }
+
+ // Destroy every table-pointer cache (levels start at 2) and null its slot.
+ void lmtable::delete_lmtcaches()
+ {
+#ifdef LMT_CACHE_ENABLE
+ for (int i=2; i<=max_cache_lev; i++)
+ {
+ if (lmtcache[i])
+ {
+ delete lmtcache[i];
+ }
+ lmtcache[i]=NULL;
+ }
+#endif
+ }
+
+ // Destroy both cache families.
+ void lmtable::delete_caches()
+ {
+#ifdef PS_CACHE_ENABLE
+ delete_prob_and_state_cache();
+#endif
+#ifdef LMT_CACHE_ENABLE
+ delete_lmtcaches();
+#endif
+ }
+
+ // Print usage statistics for each active (prob,state) cache level.
+ void lmtable::stat_prob_and_state_cache()
+ {
+#ifdef PS_CACHE_ENABLE
+ for (int i=1; i<=max_cache_lev; i++)
+ {
+ std::cout << "void lmtable::stat_prob_and_state_cache() level:" << i << std::endl;
+ if (prob_and_state_cache[i])
+ {
+ prob_and_state_cache[i]->stat();
+ }
+ }
+#endif
+ }
+
+ // Print usage statistics for each active table-pointer cache level.
+ void lmtable::stat_lmtcaches()
+ {
+#ifdef LMT_CACHE_ENABLE /* fix: was PS_CACHE_ENABLE, inconsistent with every other lmtcache method and with the guard in stat_caches() */
+ for (int i=2; i<=max_cache_lev; i++)
+ {
+ std::cout << "void lmtable::stat_lmtcaches() level:" << i << std::endl;
+ if (lmtcache[i])
+ {
+ lmtcache[i]->stat();
+ }
+ }
+#endif
+ }
+
+ // Print statistics for both cache families.
+ void lmtable::stat_caches()
+ {
+#ifdef PS_CACHE_ENABLE
+ stat_prob_and_state_cache();
+#endif
+#ifdef LMT_CACHE_ENABLE
+ stat_lmtcaches();
+#endif
+ }
+
+
+ // Report the fill level of each active (prob,state) cache.
+ void lmtable::used_prob_and_state_cache() const
+ {
+#ifdef PS_CACHE_ENABLE
+ for (int i=1; i<=max_cache_lev; i++)
+ {
+ if (prob_and_state_cache[i])
+ {
+ prob_and_state_cache[i]->used();
+ }
+ }
+#endif
+ }
+
+ // Report the fill level of each active table-pointer cache.
+ void lmtable::used_lmtcaches() const
+ {
+#ifdef LMT_CACHE_ENABLE
+ for (int i=2; i<=max_cache_lev; i++)
+ {
+ if (lmtcache[i])
+ {
+ lmtcache[i]->used();
+ }
+ }
+#endif
+ }
+
+ // Report fill levels for both cache families.
+ void lmtable::used_caches() const
+ {
+#ifdef PS_CACHE_ENABLE
+ used_prob_and_state_cache();
+#endif
+#ifdef LMT_CACHE_ENABLE
+ used_lmtcaches();
+#endif
+ }
+
+
+ // Reset any (prob,state) cache that has filled up, keeping its current size.
+ void lmtable::check_prob_and_state_cache_levels() const
+ {
+#ifdef PS_CACHE_ENABLE
+ for (int i=1; i<=max_cache_lev; i++)
+ {
+ if (prob_and_state_cache[i] && prob_and_state_cache[i]->isfull())
+ {
+ prob_and_state_cache[i]->reset(prob_and_state_cache[i]->cursize());
+ }
+ }
+#endif
+ }
+
+ // Reset any table-pointer cache that has filled up, keeping its current size.
+ void lmtable::check_lmtcaches_levels() const
+ {
+#ifdef LMT_CACHE_ENABLE
+ for (int i=2; i<=max_cache_lev; i++)
+ {
+ if (lmtcache[i] && lmtcache[i]->isfull())
+ {
+ lmtcache[i]->reset(lmtcache[i]->cursize());
+ }
+ }
+#endif
+ }
+
+ // Run the full-cache check on both cache families.
+ void lmtable::check_caches_levels() const
+ {
+#ifdef PS_CACHE_ENABLE
+ check_prob_and_state_cache_levels();
+#endif
+#ifdef LMT_CACHE_ENABLE
+ check_lmtcaches_levels();
+#endif
+ }
+
+ // Unconditionally reset every (prob,state) cache, growing it to the larger
+ // of its current and historical maximum size.
+ void lmtable::reset_prob_and_state_cache()
+ {
+#ifdef PS_CACHE_ENABLE
+ for (int i=1; i<=max_cache_lev; i++)
+ {
+ if (prob_and_state_cache[i])
+ {
+ prob_and_state_cache[i]->reset(MAX(prob_and_state_cache[i]->cursize(),prob_and_state_cache[i]->maxsize()));
+ }
+ }
+#endif
+ }
+
+ // Unconditionally reset every table-pointer cache, same sizing policy.
+ void lmtable::reset_lmtcaches()
+ {
+#ifdef LMT_CACHE_ENABLE
+ for (int i=2; i<=max_cache_lev; i++)
+ {
+ if (lmtcache[i])
+ {
+ lmtcache[i]->reset(MAX(lmtcache[i]->cursize(),lmtcache[i]->maxsize()));
+ }
+ }
+#endif
+ }
+
+ // Reset both cache families.
+ void lmtable::reset_caches()
+ {
+ VERBOSE(2,"void lmtable::reset_caches()" << std::endl);
+#ifdef PS_CACHE_ENABLE
+ reset_prob_and_state_cache();
+#endif
+#ifdef LMT_CACHE_ENABLE
+ reset_lmtcaches();
+#endif
+ }
+
+ // True only when caching is enabled at compile time AND every
+ // (prob,state) cache slot up to max_cache_lev has been allocated.
+ bool lmtable::are_prob_and_state_cache_active() const
+ {
+#ifdef PS_CACHE_ENABLE
+ if (max_cache_lev < 1)
+ {
+ return false;
+ }
+ for (int i=1; i<=max_cache_lev; i++)
+ {
+ if (prob_and_state_cache[i]==NULL)
+ {
+ return false;
+ }
+ }
+ return true;
+ // return prob_and_state_cache!=NULL;
+#else
+ return false;
+#endif
+ }
+
+ // Same check for the table-pointer caches (levels start at 2).
+ bool lmtable::are_lmtcaches_active() const
+ {
+#ifdef LMT_CACHE_ENABLE
+ if (max_cache_lev < 2)
+ {
+ return false;
+ }
+ for (int i=2; i<=max_cache_lev; i++)
+ {
+ if (lmtcache[i]==NULL)
+ {
+ return false;
+ }
+ }
+ return true;
+#else
+ return false;
+#endif
+ }
+
+ // Both cache families must be fully active.
+ bool lmtable::are_caches_active() const
+ {
+ return (are_prob_and_state_cache_active() && are_lmtcaches_active());
+ }
+
+ // Set the LM order to n and assign node types per level:
+ // levels 1..n-1 are (quantized) internal nodes, level n is a (quantized) leaf.
+ void lmtable::configure(int n,bool quantized)
+ {
+ VERBOSE(2,"void lmtable::configure(int n,bool quantized) with n:" << n << std::endl);
+ maxlev=n;
+ VERBOSE(2," maxlev:" << maxlev << " maxlevel():" << maxlevel() << " this->maxlevel():" << this->maxlevel() << std::endl);
+
+ //The value for index 0 is never used
+ for (int i=0; i<n; i++)
+ {
+ tbltype[i]=(quantized?QINTERNAL:INTERNAL);
+ }
+ tbltype[n]=(quantized?QLEAF:LEAF);
+ }
+
+
+ // Load an LM from `infile`; mmap>0 (or a filename ending in ".mm") requests
+ // memory-mapped loading (disabled on WIN32). On exit the dictionary is
+ // frozen (incflag 0). Exits via exit_error when the file cannot be opened.
+ void lmtable::load(const std::string &infile, int mmap)
+ {
+ VERBOSE(2,"lmtable::load(const std::string &filename, int mmap)" << std::endl);
+ VERBOSE(2,"Reading " << infile << "..." << std::endl);
+ inputfilestream inp(infile.c_str());
+
+ if (!inp.good()) {
+ VERBOSE(2, "Failed to open " << infile << "!" << std::endl);
+ exit_error(IRSTLM_ERROR_IO, "Failed to open "+infile);
+ }
+ setMaxLoadedLevel(requiredMaxlev);
+
+ //check whether memory mapping is required
+ //fix: guard the length first -- for names shorter than 3 chars,
+ //infile.size()-3 underflows (size_t) and std::string::compare throws
+ //std::out_of_range
+ if (infile.size()>=3 && infile.compare(infile.size()-3,3,".mm")==0) {
+ mmap=1;
+ }
+
+ if (mmap>0) { //check whether memory mapping can be used
+#ifdef WIN32
+ mmap=0; //don't use memory map
+#endif
+ }
+
+ load(inp,infile.c_str(),NULL,mmap);
+ getDict()->incflag(0);
+ }
+
+ // Load an LM from an already-open stream: sniff the first token to decide
+ // between the binary reader (headers "Qblmt"/"blmt") and the textual ARPA
+ // reader. keep_on_disk>0 asks for memory mapping, which for textual input
+ // requires an output filename to write the binary image to.
+ void lmtable::load(istream& inp,const char* filename,const char* outfilename,int keep_on_disk)
+ {
+ VERBOSE(2,"lmtable::load(istream& inp,...)" << std::endl);
+
+#ifdef WIN32
+ if (keep_on_disk>0) {
+ VERBOSE(2, "lmtable::load memory mapping not yet available under WIN32" << std::endl);
+ keep_on_disk = 0;
+ }
+#endif
+
+ //give a look at the header to select loading method
+ char header[MAX_LINE];
+ inp >> header;
+ VERBOSE(2, header << std::endl);
+
+ if (strncmp(header,"Qblmt",5)==0 || strncmp(header,"blmt",4)==0) {
+ loadbin(inp,header,filename,keep_on_disk);
+ } else { //input is in textual form
+
+ if (keep_on_disk && outfilename==NULL) {
+ VERBOSE(2, "Load Error: inconsistent setting. Passed input file: textual. Memory map: yes. Outfilename: not specified." << std::endl);
+ // NOTE(review): exits with status 0 on an error path; should probably be non-zero
+ exit(0);
+ }
+
+ loadtxt(inp,header,outfilename,keep_on_disk);
+ }
+
+ VERBOSE(2, "OOV code is " << lmtable::getDict()->oovcode() << std::endl);
+ }
+
+
+ //load language model on demand through a word-list file
+ // NOTE(review): as written this only builds a temporary dictionary from the
+ // word set and returns 1; no LM entries are actually (re)loaded here.
+ int lmtable::reload(std::set<string> words)
+ {
+ //build dictionary
+ dictionary dict(NULL,(int)words.size());
+ dict.incflag(1);
+
+ std::set<string>::iterator w;
+ for (w = words.begin(); w != words.end(); ++w)
+ dict.encode((*w).c_str());
+
+ return 1;
+ }
+
+
+
+ // Read the quantization codebook for level `Order` from a qARPA stream:
+ // the number of centers, then one probability centroid per center and
+ // (for non-leaf levels) one back-off centroid per center.
+ void lmtable::load_centers(istream& inp,int Order)
+ {
+ char line[MAX_LINE];
+
+ //first read the coodebook
+ VERBOSE(2, Order << " read code book " << std::endl);
+ inp >> NumCenters[Order];
+ Pcenters[Order]=new float[NumCenters[Order]];
+ Bcenters[Order]=(Order<maxlev?new float[NumCenters[Order]]:NULL);
+
+ for (int c=0; c<NumCenters[Order]; c++) {
+ inp >> Pcenters[Order][c];
+ if (Order<maxlev) inp >> Bcenters[Order][c];
+ };
+ //empty the last line
+ inp.getline((char*)line,MAX_LINE);
+ }
+
+ // Dispatch textual-ARPA loading to the mmap-backed or RAM-backed reader.
+ // The RAM path also generates the OOV code once the dictionary is complete.
+ void lmtable::loadtxt(istream& inp,const char* header,const char* outfilename,int mmap)
+ {
+ if (mmap > 0) {
+ loadtxt_mmap(inp, header, outfilename);
+ return;
+ }
+ loadtxt_ram(inp, header);
+ lmtable::getDict()->genoovcode();
+ }
+
+ // Load a textual (q/i)ARPA LM while streaming the binary tables straight
+ // into a memory-mapped scratch file <outfilename>-ngrams; afterwards writes
+ // a header+dictionary file, concatenates the two into <outfilename> via
+ // shell commands, and then EXITS THE PROCESS (no return to the caller).
+ void lmtable::loadtxt_mmap(istream& inp,const char* header,const char* outfilename)
+ {
+
+ char nameNgrams[BUFSIZ];
+ char nameHeader[BUFSIZ];
+
+ FILE *fd = NULL;
+ table_pos_t filesize=0;
+
+ int Order,n;
+
+ //char *SepString = " \t\n"; unused
+
+ //open input stream and prepare an input string
+ char line[MAX_LINE];
+
+ //prepare word dictionary
+ //dict=(dictionary*) new dictionary(NULL,1000000,NULL,NULL);
+ lmtable::getDict()->incflag(1);
+
+ //check the header to decide if the LM is quantized or not
+ isQtable=(strncmp(header,"qARPA",5)==0?true:false);
+
+ //check the header to decide if the LM table is incomplete
+ isItable=(strncmp(header,"iARPA",5)==0?true:false);
+
+ if (isQtable) {
+ int maxlevel_h;
+ //check if header contains other infos
+ inp >> line;
+ if (!(maxlevel_h=atoi(line))) {
+ VERBOSE(2, "loadtxt with mmap requires new qARPA header. Please regenerate the file." << std::endl);
+ exit(1);
+ }
+
+ // read the per-level number of quantization centers from the header
+ for (n=1; n<=maxlevel_h; n++) {
+ inp >> line;
+ if (!(NumCenters[n]=atoi(line))) {
+ VERBOSE(2, "loadtxt with mmap requires new qARPA header. Please regenerate the file." << std::endl);
+ // NOTE(review): exits with status 0 on an error path; sibling branch above uses exit(1)
+ exit(0);
+ }
+ }
+ }
+
+ //we will configure the table later we we know the maxlev;
+ bool yetconfigured=false;
+
+ VERBOSE(2,"loadtxtmmap()" << std::endl);
+
+ // READ ARPA Header
+
+ while (inp.getline(line,MAX_LINE)) {
+
+ if (strlen(line)==MAX_LINE-1) {
+ VERBOSE(2,"lmtable::loadtxt_mmap: input line exceed MAXLINE (" << MAX_LINE << ") chars " << line << std::endl);
+ exit(1);
+ }
+
+ bool backslash = (line[0] == '\\');
+
+ // "ngram N=M" lines give the entry count per level
+ if (sscanf(line, "ngram %d=%d", &Order, &n) == 2) {
+ maxsize[Order] = n;
+ maxlev=Order; //upadte Order
+ VERBOSE(2,"size[" << Order << "]=" << maxsize[Order] << std::endl);
+ }
+
+ VERBOSE(2,"maxlev" << maxlev << std::endl);
+ if (maxlev>requiredMaxlev) maxlev=requiredMaxlev;
+ VERBOSE(2,"maxlev" << maxlev << std::endl);
+ VERBOSE(2,"lmtable:requiredMaxlev" << requiredMaxlev << std::endl);
+
+ if (backslash && sscanf(line, "\\%d-grams", &Order) == 1) {
+
+ //at this point we are sure about the size of the LM
+ if (!yetconfigured) {
+ configure(maxlev,isQtable);
+ yetconfigured=true;
+
+ //opening output file
+ strcpy(nameNgrams,outfilename);
+ strcat(nameNgrams, "-ngrams");
+
+ fd = fopen(nameNgrams, "w+");
+
+ // compute the size of file (only for tables and - possibly - centroids; no header nor dictionary)
+ for (int l=1; l<=maxlev; l++) {
+ if (l<maxlev)
+ filesize += (table_pos_t) maxsize[l] * nodesize(tbltype[l]) + 2 * NumCenters[l] * sizeof(float);
+ else
+ filesize += (table_pos_t) maxsize[l] * nodesize(tbltype[l]) + NumCenters[l] * sizeof(float);
+ }
+
+ // set the file to the proper size:
+ // NOTE(review): ftruncate's return value is ignored; a failure would only
+ // surface later at mmap/write time
+ ftruncate(fileno(fd),filesize);
+ table[0]=(char *)(MMap(fileno(fd),PROT_READ|PROT_WRITE,0,filesize,&tableGaps[0]));
+
+ //allocate space for tables into the file through mmap:
+ /*
+ if (maxlev>1)
+ table[1]=table[0] + (table_pos_t) (2 * NumCenters[1] * sizeof(float));
+ else
+ table[1]=table[0] + (table_pos_t) (NumCenters[1] * sizeof(float));
+ */
+
+ // lay out each level after the previous one, leaving room for centroids
+ for (int l=1; l<=maxlev; l++) {
+ if (l<maxlev)
+ table[l]=(char *)(table[l-1] + (table_pos_t) maxsize[l-1]*nodesize(tbltype[l-1]) +
+ 2 * NumCenters[l] * sizeof(float));
+ else
+ table[l]=(char *)(table[l-1] + (table_pos_t) maxsize[l-1]*nodesize(tbltype[l-1]) +
+ NumCenters[l] * sizeof(float));
+
+ VERBOSE(2,"table[" << l << "]-table[" << l-1 << "]=" << (table_pos_t) table[l]-(table_pos_t) table[l-1] << " (nodesize=" << nodesize(tbltype[l-1]) << std::endl);
+ }
+ }
+
+ loadtxt_level(inp,Order);
+
+ if (isQtable) {
+ // writing centroids on disk
+ if (Order<maxlev) {
+ memcpy(table[Order] - 2 * NumCenters[Order] * sizeof(float),
+ Pcenters[Order],
+ NumCenters[Order] * sizeof(float));
+ memcpy(table[Order] - NumCenters[Order] * sizeof(float),
+ Bcenters[Order],
+ NumCenters[Order] * sizeof(float));
+ } else {
+ memcpy(table[Order] - NumCenters[Order] * sizeof(float),
+ Pcenters[Order],
+ NumCenters[Order] * sizeof(float));
+ }
+ }
+ // To avoid huge memory write concentrated at the end of the program
+ msync(table[0],filesize,MS_SYNC);
+
+ // now we can fix table at level Order -1
+ // (not required if the input LM is in lexicographical order)
+ if (maxlev>1 && Order>1) {
+ checkbounds(Order-1);
+ // NOTE(review): startpos[Order-1] is allocated with new[] in
+ // loadtxt_level; this should be "delete []" to match
+ delete startpos[Order-1];
+ }
+ }
+ }
+
+ VERBOSE(2,"closing output file: " << nameNgrams << std::endl);
+ for (int i=1; i<=maxlev; i++) {
+ if (maxsize[i] != cursize[i]) {
+ for (int l=1; l<=maxlev; l++)
+ VERBOSE(2,"Level " << l << ": starting ngrams=" << maxsize[l] << " - actual stored ngrams=" << cursize[l] << std::endl);
+ break;
+ }
+ }
+
+ Munmap(table[0],filesize,MS_SYNC);
+ for (int l=1; l<=maxlev; l++)
+ table[l]=0; // to avoid wrong free in ~lmtable()
+ VERBOSE(2,"running fclose..." << std::endl);
+ fclose(fd);
+ VERBOSE(2,"done" << std::endl);
+
+ lmtable::getDict()->incflag(0);
+ lmtable::getDict()->genoovcode();
+
+ // saving header + dictionary
+
+ strcpy(nameHeader,outfilename);
+ strcat(nameHeader, "-header");
+ VERBOSE(2,"saving header+dictionary in " << nameHeader << "\n");
+ fstream out(nameHeader,ios::out);
+
+ // print header
+ if (isQtable) {
+ out << "Qblmt" << (isInverted?"I ":" ") << maxlev;
+ for (int i=1; i<=maxlev; i++) out << " " << maxsize[i]; // not cursize[i] because the file was already allocated
+ out << "\nNumCenters";
+ for (int i=1; i<=maxlev; i++) out << " " << NumCenters[i];
+ out << "\n";
+
+ } else {
+ out << "blmt" << (isInverted?"I ":" ") << maxlev;
+ for (int i=1; i<=maxlev; i++) out << " " << maxsize[i]; // not cursize[i] because the file was already allocated
+ out << "\n";
+ }
+
+ lmtable::getDict()->save(out);
+
+ out.close();
+ VERBOSE(2,"done" << std::endl);
+
+ // cat header+dictionary and n-grams files:
+
+ // NOTE(review): return values of system() are unchecked, and outfilename is
+ // interpolated unquoted into a shell command
+ char cmd[BUFSIZ];
+ sprintf(cmd,"cat %s >> %s", nameNgrams, nameHeader);
+ VERBOSE(2,"run cmd <" << cmd << std::endl);
+ system(cmd);
+
+ sprintf(cmd,"mv %s %s", nameHeader, outfilename);
+ VERBOSE(2,"run cmd <" << cmd << std::endl);
+ system(cmd);
+
+ removefile(nameNgrams);
+
+ //no more operations are available, the file must be saved!
+ exit(0);
+ return;
+ }
+
+
+ // Load a textual (q/i)ARPA LM fully into heap-allocated per-level tables.
+ // Table sizes come from the "ngram N=M" header lines; each "\N-grams:"
+ // section is read by loadtxt_level, and bounds of the previous level are
+ // fixed up afterwards.
+ void lmtable::loadtxt_ram(istream& inp,const char* header)
+ {
+ //open input stream and prepare an input string
+ char line[MAX_LINE];
+
+ //prepare word dictionary
+ lmtable::getDict()->incflag(1);
+
+ //check the header to decide if the LM is quantized or not
+ isQtable=(strncmp(header,"qARPA",5)==0?true:false);
+
+ //check the header to decide if the LM table is incomplete
+ isItable=(strncmp(header,"iARPA",5)==0?true:false);
+
+ //we will configure the table later when we will know the maxlev;
+ bool yetconfigured=false;
+
+ VERBOSE(2,"loadtxt_ram()" << std::endl);
+
+ // READ ARPA Header
+ int Order;
+ unsigned int n;
+
+ while (inp.getline(line,MAX_LINE)) {
+ if (strlen(line)==MAX_LINE-1) {
+ VERBOSE(2,"lmtable::loadtxt_ram: input line exceed MAXLINE (" << MAX_LINE << ") chars " << line << std::endl);
+ exit(1);
+ }
+
+ bool backslash = (line[0] == '\\');
+
+ if (sscanf(line, "ngram %d=%u", &Order, &n) == 2) {
+ maxsize[Order] = n;
+ maxlev=Order; //update Order
+ }
+
+ // honor a lower order requested by the caller
+ if (maxlev>requiredMaxlev) maxlev=requiredMaxlev;
+
+ if (backslash && sscanf(line, "\\%d-grams", &Order) == 1) {
+
+ //at this point we are sure about the size of the LM
+ if (!yetconfigured) {
+ configure(maxlev,isQtable);
+ yetconfigured=true;
+ //allocate space for loading the table of this level
+ for (int i=1; i<=maxlev; i++)
+ table[i] = new char[(table_pos_t) maxsize[i] * nodesize(tbltype[i])];
+ }
+
+ loadtxt_level(inp,Order);
+
+ // now we can fix table at level Order - 1
+ if (maxlev>1 && Order>1) {
+ checkbounds(Order-1);
+ }
+ }
+ }
+
+ lmtable::getDict()->incflag(0);
+ VERBOSE(2,"done" << std::endl);
+ }
+
+ // Read one "\N-grams:" section: for qARPA first the codebook, then one
+ // parsed n-gram per expected entry, inverted if the table is inverted and
+ // re-normalized if the table is an incomplete (iARPA) one, and finally
+ // inserted into the trie via add().
+ void lmtable::loadtxt_level(istream& inp, int level)
+ {
+ VERBOSE(2, level << "-grams: reading " << std::endl);
+
+ if (isQtable) {
+ load_centers(inp,level);
+ }
+
+ //allocate support vector to manage badly ordered n-grams
+ if (maxlev>1 && level<maxlev) {
+ startpos[level]=new table_entry_pos_t[maxsize[level]];
+ for (table_entry_pos_t c=0; c<maxsize[level]; c++) {
+ startpos[level][c]=BOUND_EMPTY1;
+ }
+ }
+
+ //prepare to read the n-grams entries
+ VERBOSE(2, maxsize[level] << " entries" << std::endl);
+
+ float prob,bow;
+
+ //put here ngrams, log10 probabilities or their codes
+ ngram ng(lmtable::getDict());
+ ngram ing(lmtable::getDict()); //support n-gram
+
+ //WE ASSUME A WELL STRUCTURED FILE!!!
+ for (table_entry_pos_t c=0; c<maxsize[level]; c++) {
+
+ if (parseline(inp,level,ng,prob,bow)) {
+
+ // if table is inverted then revert n-gram
+ if (isInverted && (level>1)) {
+ ing.invert(ng);
+ ng=ing;
+ }
+
+ //if table is in incomplete ARPA format prob is just the
+ //discounted frequency, so we need to add bow * Pr(n-1 gram)
+ if (isItable && (level>1)) {
+ //get bow of lower context
+ get(ng,ng.size,ng.size-1);
+ float rbow=0.0;
+ if (ng.lev==ng.size-1) { //found context
+ rbow=ng.bow;
+ }
+
+ // temporarily restrict the order so lprob scores the (n-1)-gram
+ int tmp=maxlev;
+ maxlev=level-1;
+ prob= log(exp((double)prob * M_LN10) + exp(((double)rbow + lprob(ng)) * M_LN10))/M_LN10;
+ maxlev=tmp;
+ }
+
+ //insert an n-gram into the TRIE table
+ if (isQtable) add(ng, (qfloat_t)prob, (qfloat_t)bow);
+ else add(ng, prob, bow);
+ }
+ }
+ VERBOSE(2, "done level " << level << std::endl);
+ }
+
+
+ // Reserve room for `size` entries at `level`, either in a memory-mapped
+ // file (mmap>0, using `outfilename`) or in plain heap memory.
+ void lmtable::expand_level(int level, table_entry_pos_t size, const char* outfilename, int mmap)
+ {
+ if (mmap > 0) {
+ expand_level_mmap(level, size, outfilename);
+ } else {
+ expand_level_nommap(level, size);
+ }
+ }
+
+ // Reserve room for `size` entries at `level` inside a memory-mapped file
+ // named <outfilename>-<level>grams, and reset the per-entry start positions
+ // used during loading.
+ void lmtable::expand_level_mmap(int level, table_entry_pos_t size, const char* outfilename)
+ {
+ maxsize[level]=size;
+
+ //getting the level-dependent filename
+ char nameNgrams[BUFSIZ];
+ sprintf(nameNgrams,"%s-%dgrams",outfilename,level);
+
+ //opening output file
+ FILE *fd = NULL;
+ fd = fopen(nameNgrams, "w+");
+ if (fd == NULL) {
+ perror("Error opening file for writing");
+ exit_error(IRSTLM_ERROR_IO, "Error opening file for writing");
+ }
+ table_pos_t filesize=(table_pos_t) maxsize[level] * nodesize(tbltype[level]);
+ // set the file to the proper size:
+ // NOTE(review): ftruncate's return value is ignored
+ ftruncate(fileno(fd),filesize);
+
+ /* Now the file is ready to be mmapped.
+ */
+ table[level]=(char *)(MMap(fileno(fd),PROT_READ|PROT_WRITE,0,filesize,&tableGaps[level]));
+ if (table[level] == MAP_FAILED) {
+ fclose(fd);
+ perror("Error mmapping the file");
+ exit_error(IRSTLM_ERROR_IO, "Error mmapping the file");
+ }
+
+ // BOUND_EMPTY1 marks "no successors recorded yet" for each entry
+ if (maxlev>1 && level<maxlev) {
+ startpos[level]=new table_entry_pos_t[maxsize[level]];
+ /*
+ LMT_TYPE ndt=tbltype[level];
+ TOCHECK XXXXXXXXX
+ int ndsz=nodesize(ndt);
+ char *found = table[level];
+ */
+ for (table_entry_pos_t c=0; c<maxsize[level]; c++) {
+ startpos[level][c]=BOUND_EMPTY1;
+ /*
+ TOCHECK XXXXXXXXX
+ found += ndsz;
+ bound(found,ndt,BOUND_EMPTY2);
+ */
+ }
+ }
+ }
+
+ // Reserve heap room for `size` entries at `level` and reset the per-entry
+ // start positions used while loading (BOUND_EMPTY1 = no successors yet).
+ void lmtable::expand_level_nommap(int level, table_entry_pos_t size)
+ {
+ VERBOSE(2,"lmtable::expand_level_nommap START level:" << level << " size:" << size << endl);
+ maxsize[level]=size;
+ table[level] = new char[(table_pos_t) maxsize[level] * nodesize(tbltype[level])];
+ if (maxlev>1 && level<maxlev) {
+ startpos[level]=new table_entry_pos_t[maxsize[level]];
+ //fix: dropped the dead ndt/ndsz/found cursor that was advanced but
+ //never read, together with the stale commented-out bound() calls
+ for (table_entry_pos_t c=0; c<maxsize[level]; c++) {
+ startpos[level][c]=BOUND_EMPTY1;
+ }
+ }
+ VERBOSE(2,"lmtable::expand_level_nommap END level:" << level << endl);
+ }
+
+ // Debug dump of one level: for internal levels prints prob, word, back-off
+ // weight, bound and start position of each entry; for the leaf level just
+ // prob and word.
+ void lmtable::printTable(int level)
+ {
+ char* tbl=table[level];
+ LMT_TYPE ndt=tbltype[level];
+ int ndsz=nodesize(ndt);
+ table_entry_pos_t printEntryN=getCurrentSize(level);
+ // if (cursize[level]>0)
+ // printEntryN=(printEntryN<cursize[level])?printEntryN:cursize[level];
+
+ cout << "level = " << level << " of size:" << printEntryN <<" ndsz:" << ndsz << " \n";
+
+ //TOCHECK: Nicola, 18 dicembre 2009
+
+ if (level<maxlev){
+ float p;
+ float bw;
+ table_entry_pos_t bnd;
+ table_entry_pos_t start;
+ for (table_entry_pos_t c=0; c<printEntryN; c++) {
+ p=prob(tbl,ndt);
+ bw=bow(tbl,ndt);
+ bnd=bound(tbl,ndt);
+ start=startpos[level][c];
+ VERBOSE(2, p << " " << word(tbl) << " -> " << dict->decode(word(tbl)) << " bw:" << bw << " bnd:" << bnd << " " << start << " tb_offset:" << tb_offset[level+1] << std::endl);
+ tbl+=ndsz;
+ }
+ }else{
+ float p;
+ for (table_entry_pos_t c=0; c<printEntryN; c++) {
+ p=prob(tbl,ndt);
+ VERBOSE(2, p << " " << word(tbl) << " -> " << dict->decode(word(tbl)) << std::endl);
+ tbl+=ndsz;
+ }
+ }
+ return;
+ }
+
+ //Checkbound with sorting of n-gram table on disk
+ // Re-orders the (level+1) table through a temporary file so that the
+ // successors of each entry at `level` become contiguous, rewriting the
+ // bounds and start positions of `level` accordingly, then reads the
+ // re-ordered successor table back in place.
+ void lmtable::checkbounds(int level)
+ {
+ VERBOSE(2,"lmtable::checkbounds START Level:" << level << endl);
+
+ if (getCurrentSize(level) > 0 ){
+
+ char* tbl=table[level];
+ char* succtbl=table[level+1];
+
+ LMT_TYPE ndt=tbltype[level];
+ LMT_TYPE succndt=tbltype[level+1];
+ int ndsz=nodesize(ndt);
+ int succndsz=nodesize(succndt);
+
+ //re-order table at level+1 on disk
+ //generate random filename to avoid collisions
+ std::string filePath;
+ mfstream out;
+ createtempfile(out, filePath, ios::out|ios::binary);
+
+ if (out.fail())
+ {
+ perror("checkbound creating out on filePath");
+ exit(4);
+ }
+
+ table_entry_pos_t start,end,newend;
+ table_entry_pos_t succ;
+
+ //re-order table at level l+1
+ char* found;
+ for (table_entry_pos_t c=0; c<cursize[level]; c++) {
+ found=tbl+(table_pos_t) c*ndsz;
+ start=startpos[level][c];
+ end=boundwithoffset(found,ndt,level);
+
+ //newend is the running total of successors written so far
+ if (c>0) newend=boundwithoffset(found-ndsz,ndt,level);
+ else newend=0;
+
+ //if start==BOUND_EMPTY1 there are no successors for this entry
+ if (start==BOUND_EMPTY1){
+ succ=0;
+ }
+ else{
+ MY_ASSERT(end>start);
+ succ=end-start;
+ }
+
+ startpos[level][c]=newend;
+ newend += succ;
+
+ MY_ASSERT(newend<=cursize[level+1]);
+
+ if (succ>0) {
+ out.write((char*)(succtbl + (table_pos_t) start * succndsz),(table_pos_t) succ * succndsz);
+ if (!out.good()) {
+ VERBOSE(2," Something went wrong while writing temporary file " << filePath << " Maybe there is not enough space on this filesystem" << endl);
+ out.close();
+ //fix: remove the temp file BEFORE exiting; it was unreachable
+ //after exit(2) in the original
+ removefile(filePath);
+ exit(2);
+ }
+ }
+
+ boundwithoffset(found,ndt,newend,level);
+ }
+ out.close();
+ if (out.fail())
+ {
+ perror("error closing out");
+ exit(4);
+ }
+
+ //read the re-ordered successor table back into memory
+ fstream inp(filePath.c_str(),ios::in|ios::binary);
+ if (inp.fail())
+ {
+ perror("error opening inp");
+ exit(4);
+ }
+
+ inp.read(succtbl,(table_pos_t) cursize[level+1]*succndsz);
+ inp.close();
+ if (inp.fail())
+ {
+ perror("error closing inp");
+ exit(4);
+ }
+
+ removefile(filePath);
+ }
+ VERBOSE(2,"lmtable::checkbounds END Level:" << level << endl);
+ }
+
+ //Add method inserts n-grams in the table structure. It is ONLY used during
+ //loading of LMs in text format. It searches for the prefix, then it adds the
+ //suffix to the last level and updates the start-end positions.
+ // Offset-aware variant of add(): identical logic but reads/writes bounds
+ // through boundwithoffset(). Returns 1 on insertion, 0 when the back-off
+ // prefix is missing.
+ int lmtable::addwithoffset(ngram& ng, float iprob, float ibow)
+ {
+ char *found;
+ LMT_TYPE ndt=tbltype[1]; //default initialization
+ int ndsz=nodesize(ndt); //default initialization
+ static int no_more_msg = 0;
+
+ if (ng.size>1) {
+
+ // find the prefix starting from the first level
+ table_entry_pos_t start=0;
+ table_entry_pos_t end=cursize[1];
+ table_entry_pos_t position;
+
+ for (int l=1; l<ng.size; l++) {
+
+ ndt=tbltype[l];
+ ndsz=nodesize(ndt);
+
+ if (search(l,start,(end-start),ndsz, ng.wordp(ng.size-l+1),LMT_FIND, &found)) {
+
+ //update start and end positions for next step
+ if (l < (ng.size-1)) {
+ //set start position
+ if (found==table[l]){
+ start=0; //first pos in table
+ }
+ else {
+ position=(table_entry_pos_t) (((table_pos_t) (found)-(table_pos_t) table[l])/ndsz);
+ start=startpos[l][position];
+ }
+
+ end=boundwithoffset(found,ndt,l);
+ }
+ } else {
+ // prefix not found: drop the n-gram (warn once, then sparsely)
+ if (!no_more_msg)
+ {
+ VERBOSE(2, "warning: missing back-off (at level " << l << ") for ngram " << ng << " (and possibly for others)" << std::endl);
+ }
+ no_more_msg++;
+ if (!(no_more_msg % 5000000))
+ {
+ VERBOSE(2, "!" << std::endl);
+ }
+ return 0;
+ }
+ }
+
+ // update book keeping information about level ng-size -1.
+ position=(table_entry_pos_t) (((table_pos_t) found-(table_pos_t) table[ng.size-1])/ndsz);
+
+ // if this is the first successor update start position in the previous level
+ if (startpos[ng.size-1][position]==BOUND_EMPTY1)
+ startpos[ng.size-1][position]=cursize[ng.size];
+
+ //always update ending position
+ boundwithoffset(found,ndt,cursize[ng.size]+1,ng.size-1);
+ }
+
+ // just add at the end of table[ng.size]
+
+ MY_ASSERT(cursize[ng.size]< maxsize[ng.size]); // is there enough space?
+ ndt=tbltype[ng.size];
+ ndsz=nodesize(ndt);
+
+ found=table[ng.size] + ((table_pos_t) cursize[ng.size] * ndsz);
+ word(found,*ng.wordp(1));
+ prob(found,ndt,iprob);
+ if (ng.size<maxlev) {
+ //find the bound of the previous entry
+ table_entry_pos_t newend;
+ if (found==table[ng.size]) newend=0; //first pos in table
+ else newend=boundwithoffset(found - ndsz,ndt,ng.size);
+
+ bow(found,ndt,ibow);
+ boundwithoffset(found,ndt,newend,ng.size);
+ }
+ cursize[ng.size]++;
+
+ // progress marker every 5M insertions
+ if (!(cursize[ng.size]%5000000))
+ {
+ VERBOSE(1, "." << std::endl);
+ }
+ return 1;
+
+ };
+
+
+ //template<typename TA, typename TB>
+ //int lmtable::add(ngram& ng, TA iprob,TB ibow)
+
+ // Insert an n-gram (with its prob and back-off weight) at the end of the
+ // level-ng.size table, after locating its (n-1)-gram prefix level by level
+ // and updating the prefix's start/bound bookkeeping. Near-duplicate of
+ // addwithoffset(), but uses plain bound() accessors. Returns 1 on
+ // insertion, 0 when the back-off prefix is missing.
+ int lmtable::add(ngram& ng, float iprob, float ibow)
+ {
+ char *found;
+ LMT_TYPE ndt=tbltype[1]; //default initialization
+ int ndsz=nodesize(ndt); //default initialization
+ static int no_more_msg = 0;
+
+ if (ng.size>1) {
+
+ // find the prefix starting from the first level
+ table_entry_pos_t start=0;
+ table_entry_pos_t end=cursize[1];
+ table_entry_pos_t position;
+
+ for (int l=1; l<ng.size; l++) {
+
+ ndt=tbltype[l];
+ ndsz=nodesize(ndt);
+
+ if (search(l,start,(end-start),ndsz, ng.wordp(ng.size-l+1),LMT_FIND, &found)) {
+
+ //update start and end positions for next step
+ if (l < (ng.size-1)) {
+ //set start position
+ if (found==table[l]){
+ start=0; //first pos in table
+ }
+ else {
+ position=(table_entry_pos_t) (((table_pos_t) (found)-(table_pos_t) table[l])/ndsz);
+ start=startpos[l][position];
+ }
+
+ end=bound(found,ndt);
+ }
+ }
+ else {
+ // prefix not found: drop the n-gram (warn once, then sparsely)
+ if (!no_more_msg)
+ {
+ VERBOSE(2, "warning: missing back-off (at level " << l << ") for ngram " << ng << " (and possibly for others)" << std::endl);
+ }
+ no_more_msg++;
+ if (!(no_more_msg % 5000000))
+ {
+ VERBOSE(2, "!" << std::endl);
+ }
+ return 0;
+ }
+ }
+
+ // update book keeping information about level ng-size -1.
+ position=(table_entry_pos_t) (((table_pos_t) found-(table_pos_t) table[ng.size-1])/ndsz);
+
+ // if this is the first successor update start position in the previous level
+ if (startpos[ng.size-1][position]==BOUND_EMPTY1)
+ startpos[ng.size-1][position]=cursize[ng.size];
+
+ //always update ending position
+ bound(found,ndt,cursize[ng.size]+1);
+ }
+
+ // just add at the end of table[ng.size]
+
+ MY_ASSERT(cursize[ng.size]< maxsize[ng.size]); // is there enough space?
+ ndt=tbltype[ng.size];
+ ndsz=nodesize(ndt);
+
+ found=table[ng.size] + ((table_pos_t) cursize[ng.size] * ndsz);
+ word(found,*ng.wordp(1));
+ prob(found,ndt,iprob);
+ if (ng.size<maxlev) {
+ //find the bound of the previous entry
+ table_entry_pos_t newend;
+ if (found==table[ng.size]) newend=0; //first pos in table
+ else newend=bound(found - ndsz,ndt);
+
+ bow(found,ndt,ibow);
+ bound(found,ndt,newend);
+ }
+
+ cursize[ng.size]++;
+
+ // progress marker every 5M insertions
+ if (!(cursize[ng.size]%5000000))
+ {
+ VERBOSE(1, "." << std::endl);
+ }
+ return 1;
+
+ };
+
+
+ // Locate, within level `lev` restricted to `n` entries starting at entry
+ // `offs` (each `sz` bytes wide), the entry whose code matches ngp[0].
+ // Level 1 is a direct 1-1 index into the vocabulary; deeper levels use
+ // binary search. On success returns the entry (also stored in *found);
+ // on failure returns NULL. Only LMT_FIND is supported.
+ void *lmtable::search(int lev,
+ table_entry_pos_t offs,
+ table_entry_pos_t n,
+ int sz,
+ int *ngp,
+ LMT_ACTION action,
+ char **found)
+ {
+ //assume 1-grams is a 1-1 map of the vocabulary
+ //CHECK: explicit cast of n into float because table_pos_t could be unsigned and larger than MAXINT
+ if (lev==1) return *found=(*ngp < (float) n ? table[1] + (table_pos_t)*ngp * sz:NULL);
+
+ //prepare table to be searched with mybsearch
+ char* tb;
+ tb=table[lev] + (table_pos_t) offs * sz;
+ //prepare search pattern
+ char w[LMTCODESIZE];
+ putmem(w,ngp[0],0,LMTCODESIZE);
+
+ table_entry_pos_t idx=0; // index returned by mybsearch
+ *found=NULL; //initialize output variable
+
+ totbsearch[lev]++;
+ switch(action) {
+ case LMT_FIND:
+ if (!tb || !mybsearch(tb,n,sz,w,&idx)) {
+ return NULL;
+ } else {
+ return *found=tb + ((table_pos_t)idx * sz);
+ }
+ default:
+ //fix: message previously read "this option is available", missing "not"
+ error("lmtable::search: this option is not available");
+ };
+ return NULL;
+ }
+
+
+ /* returns idx with the first position in ar with entry >= key */
+
+ /* Binary (optionally interpolation) search over 'n' sorted fixed-size records
+    of 'size' bytes in 'ar', whose leading LMTCODESIZE bytes are the word code
+    compared with codecmp(). Returns 1 on an exact match with *idx at the match;
+    otherwise returns 0 with *idx at the first position with entry >= key. */
+ int lmtable::mybsearch(char *ar, table_entry_pos_t n, int size, char *key, table_entry_pos_t *idx)
+ {
+ if (n==0) return 0;
+
+ *idx=0;
+ //'register' removed: deprecated since C++11 and ignored by modern compilers
+ table_entry_pos_t low=0, high=n;
+ unsigned char *p;
+ int result;
+
+#ifdef INTERP_SEARCH
+
+ char *lp=NULL;
+ char *hp=NULL;
+
+#endif
+
+ while (low < high) {
+
+#ifdef INTERP_SEARCH
+ //use interpolation search only for intervals with at least 10000 entries
+
+ if ((high-low)>=10000) {
+
+ lp=(char *) (ar + (low * size));
+ if (codecmp((char *)key,lp)<0) {
+ *idx=low;
+ return 0;
+ }
+
+ hp=(char *) (ar + ((high-1) * size));
+ if (codecmp((char *)key,hp)>0) {
+ *idx=high;
+ return 0;
+ }
+
+ //estimate position by linear interpolation between the boundary codes
+ *idx= low + ((high-1)-low) * codediff((char *)key,lp)/codediff(hp,(char *)lp);
+ } else
+#endif
+ *idx = (low + high) / 2;
+
+ //after redefining the interval there is no guarantee
+ //that wlp <= wkey <= whigh
+
+ p = (unsigned char *) (ar + (*idx * size));
+ result=codecmp((char *)key,(char *)p);
+
+ if (result < 0)
+ high = *idx;
+
+ else if (result > 0)
+ low = ++(*idx);
+ else
+ return 1;
+ }
+
+ *idx=low;
+
+ return 0;
+
+ }
+
+
+ // generates a LM copy for a smaller dictionary
+
+ // generates a LM copy for a smaller dictionary
+ // slmt: destination table (inherits level structure and quantization data);
+ // subdict: target dictionary; keepunigr: if true, all unigrams are kept with
+ // the original word encoding, and only higher-order entries are filtered.
+ void lmtable::cpsublm(lmtable* slmt, dictionary* subdict,bool keepunigr)
+ {
+
+ //let slmt inherit all features of this lmtable
+ slmt->configure(maxlev,isQtable);
+ slmt->dict=new dictionary((keepunigr?dict:subdict),false);
+
+ if (isQtable) {
+ for (int i=1; i<=maxlev; i++) {
+ slmt->NumCenters[i]=NumCenters[i];
+ slmt->Pcenters[i]=new float [NumCenters[i]];
+ memcpy(slmt->Pcenters[i],Pcenters[i],NumCenters[i] * sizeof(float));
+
+ if (i<maxlev) {
+ slmt->Bcenters[i]=new float [NumCenters[i]];
+ memcpy(slmt->Bcenters[i],Bcenters[i],NumCenters[i] * sizeof(float));
+ }
+ }
+ }
+
+ //manage dictionary information
+
+ //generate OOV codes and build dictionary lookup table
+ dict->genoovcode();
+ slmt->dict->genoovcode();
+ subdict->genoovcode();
+
+ int* lookup=new int [dict->size()];
+
+ for (int c=0; c<dict->size(); c++) {
+ lookup[c]=subdict->encode(dict->decode(c));
+ if (c != dict->oovcode() && lookup[c] == subdict->oovcode())
+ lookup[c]=-1; // words of this->dict that are not in slmt->dict
+ }
+
+ //variables useful to navigate in the lmtable structure
+ LMT_TYPE ndt,pndt;
+ int ndsz,pndsz;
+ char *entry, *newentry;
+ table_entry_pos_t start, end, origin;
+
+ for (int l=1; l<=maxlev; l++) {
+
+ slmt->cursize[l]=0;
+ slmt->table[l]=NULL;
+
+ if (l==1) { //1-gram level
+
+ ndt=tbltype[l];
+ ndsz=nodesize(ndt);
+
+ for (table_entry_pos_t p=0; p<cursize[l]; p++) {
+
+ entry=table[l] + (table_pos_t) p * ndsz;
+ if (lookup[word(entry)]!=-1 || keepunigr) {
+
+ //grow destination table in chunks of dict->size() entries
+ if ((slmt->cursize[l] % slmt->dict->size()) ==0)
+ slmt->table[l]=(char *)reallocf(slmt->table[l],((table_pos_t) slmt->cursize[l] + (table_pos_t) slmt->dict->size()) * ndsz);
+
+ newentry=slmt->table[l] + (table_pos_t) slmt->cursize[l] * ndsz;
+ memcpy(newentry,entry,ndsz);
+ if (!keepunigr) //do not change encoding if keepunigr is true
+ slmt->word(newentry,lookup[word(entry)]);
+
+ if (l<maxlev)
+ slmt->bound(newentry,ndt,p); //store in bound the entry itself (**) !!!!
+ slmt->cursize[l]++;
+ }
+ }
+ }
+
+ else { //n-grams n>1: scan lower order table
+
+ pndt=tbltype[l-1];
+ pndsz=nodesize(pndt);
+ ndt=tbltype[l];
+ ndsz=nodesize(ndt);
+
+ for (table_entry_pos_t p=0; p<slmt->cursize[l-1]; p++) {
+
+ //determine start and end of successors of this entry
+ origin=slmt->bound(slmt->table[l-1] + (table_pos_t)p * pndsz,pndt); //position of n-1 gram in this table (**)
+ if (origin == 0) start=0; //succ start at first pos in table[l]
+ else start=bound(table[l-1] + (table_pos_t)(origin-1) * pndsz,pndt);//succ start after end of previous entry
+ end=bound(table[l-1] + (table_pos_t)origin * pndsz,pndt); //succ end where indicated
+
+ if (!keepunigr || lookup[word(table[l-1] + (table_pos_t)origin * pndsz)]!=-1) {
+ while (start < end) {
+
+ entry=table[l] + (table_pos_t) start * ndsz;
+
+ if (lookup[word(entry)]!=-1) {
+
+ if ((slmt->cursize[l] % slmt->dict->size()) ==0)
+ slmt->table[l]=(char *)reallocf(slmt->table[l],(table_pos_t) (slmt->cursize[l]+slmt->dict->size()) * ndsz);
+
+ newentry=slmt->table[l] + (table_pos_t) slmt->cursize[l] * ndsz;
+ memcpy(newentry,entry,ndsz);
+ if (!keepunigr) //do not change encoding if keepunigr is true
+ slmt->word(newentry,lookup[word(entry)]);
+
+ if (l<maxlev)
+ slmt->bound(newentry,ndt,start); //store in bound the entry itself!!!!
+ slmt->cursize[l]++;
+ }
+ start++;
+ }
+ }
+
+ //updated bound information of incoming entry
+ slmt->bound(slmt->table[l-1] + (table_pos_t) p * pndsz, pndt,slmt->cursize[l]);
+ }
+ }
+ }
+
+ //fix: lookup was leaked on every call
+ delete [] lookup;
+
+ return;
+ }
+
+
+
+ // saves a LM table in text format
+
+ // Saves the LM table in (q)ARPA text format to 'filename'.
+ // For quantized tables a "qARPA" header line and per-level codebooks are
+ // emitted; pruned tables report counts recomputed by ngcnt().
+ void lmtable::savetxt(const char *filename)
+ {
+
+ fstream out(filename,ios::out);
+ table_entry_pos_t cnt[1+MAX_NGRAM];
+ int l;
+
+ // out.precision(7);
+ out.precision(6);
+
+ // quantized header: "qARPA <maxlev> <centers per level...>"
+ if (isQtable) {
+ out << "qARPA " << maxlev;
+ for (l=1; l<=maxlev; l++)
+ out << " " << NumCenters[l];
+ out << endl;
+ }
+
+ ngram ng(lmtable::getDict(),0);
+
+ VERBOSE(2, "savetxt: " << filename << std::endl);
+
+ if (isPruned) ngcnt(cnt); //check size of table by considering pruned n-grams
+
+ out << "\n\\data\\\n";
+ char buff[100];
+ for (l=1; l<=maxlev; l++) {
+ sprintf(buff,"ngram %2d=%10d\n",l,(isPruned?cnt[l]:cursize[l]));
+ out << buff;
+ }
+ out << "\n";
+
+ // one section per level; the recursive dumplm() walks the trie from the
+ // full unigram range down to level l
+ for (l=1; l<=maxlev; l++) {
+
+ out << "\n\\" << l << "-grams:\n";
+ VERBOSE(2, "save: " << (isPruned?cnt[l]:cursize[l]) << " " << l << "-grams" << std::endl);
+ if (isQtable) {
+ // codebook: number of centers, then prob (and back-off) center values
+ out << NumCenters[l] << "\n";
+ for (int c=0; c<NumCenters[l]; c++) {
+ out << Pcenters[l][c];
+ if (l<maxlev) out << " " << Bcenters[l][c];
+ out << "\n";
+ }
+ }
+
+ ng.size=0;
+ dumplm(out,ng,1,l,0,cursize[1]);
+
+ }
+
+ out << "\\end\\\n";
+ VERBOSE(2, "done" << std::endl);
+ }
+
+
+
+ // Saves the LM table in binary format: text header line(s), the dictionary,
+ // then the raw per-level tables (and codebooks when quantized).
+ // Pruned tables cannot be serialized in binary form.
+ void lmtable::savebin(const char *filename)
+ {
+ VERBOSE(2,"lmtable::savebin START " << filename << "\n");
+
+ if (isPruned) {
+ VERBOSE(2,"lmtable::savebin: pruned LM cannot be saved in binary form\n");
+ exit(0); //NOTE(review): exiting with 0 on an error path looks dubious — confirm callers
+ }
+
+ //fix: open in binary mode, consistently with savebin_level_nommap(); the
+ //payload below is written with out.write() and must not be newline-translated
+ fstream out(filename,ios::out|ios::binary);
+
+ // print header
+ if (isQtable) {
+ out << "Qblmt" << (isInverted?"I":"") << " " << maxlev;
+ for (int i=1; i<=maxlev; i++) out << " " << cursize[i];
+ out << "\nNumCenters";
+ for (int i=1; i<=maxlev; i++) out << " " << NumCenters[i];
+ out << "\n";
+
+ } else {
+ out << "blmt" << (isInverted?"I":"") << " " << maxlev;
+ char buff[100];
+ for (int i=1; i<=maxlev; i++){
+ sprintf(buff," %10d",cursize[i]);
+ out << buff;
+ }
+ out << "\n";
+ }
+
+ lmtable::getDict()->save(out);
+
+ // raw dump of each level: codebooks first (quantized only), then the table
+ for (int i=1; i<=maxlev; i++) {
+ if (isQtable) {
+ out.write((char*)Pcenters[i],NumCenters[i] * sizeof(float));
+ if (i<maxlev)
+ out.write((char *)Bcenters[i],NumCenters[i] * sizeof(float));
+ }
+ out.write(table[i],(table_pos_t) cursize[i]*nodesize(tbltype[i]));
+ }
+
+ VERBOSE(2,"lmtable::savebin: END\n");
+ }
+
+ // Writes only the dictionary section of the binary LM format to 'out'.
+ void lmtable::savebin_dict(std::fstream& out)
+ {
+ VERBOSE(2,"savebin_dict ...\n");
+ getDict()->save(out);
+ }
+
+
+
+ // Appends the data of one level to 'out', dispatching on the mmap flag.
+ // Levels holding no entries are skipped entirely.
+ void lmtable::appendbin_level(int level, fstream &out, int mmap)
+ {
+ if (getCurrentSize(level) <= 0) return; //nothing to write
+ if (mmap > 0) {
+ appendbin_level_mmap(level, out);
+ } else {
+ appendbin_level_nommap(level, out);
+ }
+ }
+
+ // Raw-appends the in-memory table of one level to an already-open stream.
+ // Quantized tables would also need their codebooks written: NOT IMPLEMENTED.
+ void lmtable::appendbin_level_nommap(int level, fstream &out)
+ {
+ VERBOSE(2,"lmtable:appendbin_level_nommap START Level:" << level << std::endl);
+
+ MY_ASSERT(level<=maxlev);
+
+ //header handling for quantized tables: NOT IMPLEMENTED
+
+ VERBOSE(3,"appending " << cursize[level] << " (maxsize:" << maxsize[level] << ") " << level << "-grams" << " table " << (void*) table << " table[level] " << (void*) table[level] << endl);
+
+ //codebook dump for quantized tables: NOT IMPLEMENTED
+
+ out.write(table[level],(table_pos_t) cursize[level]*nodesize(tbltype[level]));
+
+ if (!out.good()) {
+ perror("Something went wrong while writing");
+ out.close();
+ exit(2);
+ }
+
+ VERBOSE(2,"lmtable:appendbin_level_nommap END Level:" << level << std::endl);
+ }
+
+
+ // In mmap mode the level data already lives in the on-disk file, so an
+ // explicit append is a no-op; the stream argument is intentionally unused.
+ void lmtable::appendbin_level_mmap(int level, fstream &out)
+ {
+ UNUSED(out);
+ VERBOSE(2,"appending " << level << " (Actually do nothing)" << std::endl);
+ }
+
+ // Persists one level to its own file, picking the mmap or plain-memory path.
+ void lmtable::savebin_level(int level, const char* outfilename, int mmap)
+ {
+ if (mmap > 0) {
+ savebin_level_mmap(level, outfilename);
+ } else {
+ savebin_level_nommap(level, outfilename);
+ }
+ }
+
+ // Writes the raw in-memory table of one level to "<outfilename>-<level>grams".
+ // On write failure the partial file is removed before exiting.
+ void lmtable::savebin_level_nommap(int level, const char* outfilename)
+ {
+ VERBOSE(2,"lmtable:savebin_level_nommap START" << requiredMaxlev << std::endl);
+
+ /*
+ if (isPruned){
+ cerr << "savebin_level (level " << level << "): pruned LM cannot be saved in binary form\n";
+ exit(0);
+ }
+ */
+
+ MY_ASSERT(level<=maxlev);
+
+ // per-level file name derived from the base output name
+ char nameNgrams[BUFSIZ];
+ sprintf(nameNgrams,"%s-%dgrams",outfilename,level);
+
+ fstream out(nameNgrams, ios::out|ios::binary);
+
+ if (out.fail())
+ {
+ perror("cannot be opened");
+ exit(3);
+ }
+
+ // print header
+ if (isQtable) {
+ //NOT IMPLEMENTED
+ } else {
+ //do nothing
+ }
+
+ VERBOSE(3,"saving " << cursize[level] << "(maxsize:" << maxsize[level] << ") " << level << "-grams in " << nameNgrams << " table " << (void*) table << " table[level] " << (void*) table[level] << endl);
+ if (isQtable) {
+ //NOT IMPLEMENTED
+ }
+
+ // raw dump: cursize[level] records of nodesize(tbltype[level]) bytes
+ out.write(table[level],(table_pos_t) cursize[level]*nodesize(tbltype[level]));
+
+ if (!out.good()) {
+ VERBOSE(2," Something went wrong while writing temporary file " << nameNgrams << endl);
+ out.close();
+ removefile(nameNgrams);
+ exit(2);
+ }
+ out.close();
+ if (out.fail())
+ {
+ perror("cannot be closed");
+ exit(3);
+ }
+
+ VERBOSE(2,"lmtable:savebin_level_nommap END" << requiredMaxlev << std::endl);
+ }
+
+ // With mmap the level data is already backed by its per-level file; nothing
+ // needs to be written here, only a progress message is reported.
+ void lmtable::savebin_level_mmap(int level, const char* outfilename)
+ {
+ char nameNgrams[BUFSIZ];
+ sprintf(nameNgrams,"%s-%dgrams",outfilename,level);
+ VERBOSE(2,"saving " << level << "-grams probs in " << nameNgrams << " (Actually do nothing)" <<std::endl);
+ }
+
+
+
+ // Dumps per-level bookkeeping statistics for every level of the table.
+ void lmtable::print_table_stat()
+ {
+ VERBOSE(2,"printing statistics of tables" << endl);
+ for (int lev=1; lev<=maxlev; lev++)
+ print_table_stat(lev);
+ }
+
+ // Dumps the bookkeeping state (sizes, offset, raw pointers) of one level.
+ void lmtable::print_table_stat(int lev)
+ {
+ VERBOSE(2," level: " << lev);
+ VERBOSE(2," maxsize[level]:" << maxsize[lev]);
+ VERBOSE(2," cursize[level]:" << cursize[lev]);
+ VERBOSE(2," tb_offset[level]:" << tb_offset[lev]);
+ VERBOSE(2," table:" << (void*) table);
+ VERBOSE(2," table[level]:" << (void*) table[lev]);
+ VERBOSE(2," tableGaps[level]:" << (void*) tableGaps[lev] << std::endl);
+ }
+
+ //concatenate corresponding single level files of two different tables for each level
+ void lmtable::concatenate_all_levels(const char* fromfilename, const char* tofilename){
+ //single level files should have a name derived from "filename"
+ //there no control that the tables have the same size
+ for (int i=1; i<=maxlevel(); i++) {
+ concatenate_single_level(i, fromfilename, tofilename);
+ }
+ }
+
+ //concatenate corresponding single level files of two different tables
+ void lmtable::concatenate_single_level(int level, const char* fromfilename, const char* tofilename){
+ //single level files should have a name derived from "fromfilename" and "tofilename"
+ char fromnameNgrams[BUFSIZ];
+ char tonameNgrams[BUFSIZ];
+ sprintf(fromnameNgrams,"%s-%dgrams",fromfilename,level);
+ sprintf(tonameNgrams,"%s-%dgrams",tofilename,level);
+
+ VERBOSE(2,"concatenating " << level << "-grams probs from " << fromnameNgrams << " to " << tonameNgrams<< std::endl);
+
+
+ //concatenating of new table to the existing data
+ char cmd[BUFSIZ];
+ sprintf(cmd,"cat %s >> %s", fromnameNgrams, tonameNgrams);
+ system(cmd);
+ }
+
+ //remove all single level files
+ void lmtable::remove_all_levels(const char* filename){
+ //single level files should have a name derived from "filename"
+ for (int i=1; i<=maxlevel(); i++) {
+ remove_single_level(i,filename);
+ }
+ }
+
+ //remove a single level file
+ void lmtable::remove_single_level(int level, const char* filename){
+ //single level files should have a name derived from "filename"
+ char nameNgrams[BUFSIZ];
+ sprintf(nameNgrams,"%s-%dgrams",filename,level);
+
+ //removing temporary files
+ removefile(nameNgrams);
+ }
+
+
+
+ //delete the table of a single level
+ void lmtable::delete_level(int level, const char* outfilename, int mmap){
+ if (mmap>0)
+ delete_level_mmap(level, outfilename);
+ else {
+ delete_level_nommap(level);
+ }
+ }
+
+ // Unmaps the memory-mapped storage of one level and resets its size counters.
+ void lmtable::delete_level_mmap(int level, const char* outfilename)
+ {
+ //getting the level-dependent filename
+ //NOTE(review): nameNgrams is computed but never used in this function
+ char nameNgrams[BUFSIZ];
+ sprintf(nameNgrams,"%s-%dgrams",outfilename,level);
+
+ //compute exact filesize
+ table_pos_t filesize=(table_pos_t) cursize[level] * nodesize(tbltype[level]);
+
+ // set the file to the proper size:
+ //table[level] was shifted forward by tableGaps[level] when mapped, so the
+ //unmap must cover the original page-aligned base and length
+ Munmap(table[level]-tableGaps[level],(table_pos_t) filesize+tableGaps[level],0);
+
+ maxsize[level]=cursize[level]=0;
+ }
+
+ // Frees the heap-allocated storage of one level and resets its size counters.
+ void lmtable::delete_level_nommap(int level)
+ {
+ //fix: table[level] is allocated with new char[...] (see loadbin_level and
+ //resize_level_nommap), so array delete is required; plain delete is UB
+ delete [] table[level];
+ maxsize[level]=cursize[level]=0;
+ }
+
+ //merge every per-level temporary file back into "filename"
+ void lmtable::compact_all_levels(const char* filename){
+ for (int lev=1; lev<=maxlevel(); lev++) {
+ compact_single_level(lev,filename);
+ }
+ }
+
+ // Appends the per-level file "<filename>-<level>grams" to "filename" via a
+ // shell "cat", then removes the temporary file.
+ void lmtable::compact_single_level(int level, const char* filename)
+ {
+ char nameNgrams[BUFSIZ];
+ sprintf(nameNgrams,"%s-%dgrams",filename,level);
+
+ VERBOSE(2,"concatenating " << level << "-grams probs from " << nameNgrams << " to " << filename<< std::endl);
+
+
+ //concatenating of new table to the existing data
+ //NOTE(review): file names are interpolated into a shell command and the
+ //return value of system() is ignored — confirm callers control these paths
+ char cmd[BUFSIZ];
+ sprintf(cmd,"cat %s >> %s", nameNgrams, filename);
+ system(cmd);
+
+ //removing temporary files
+ removefile(nameNgrams);
+ }
+
+ // Shrinks the allocation of one non-empty level to its current fill.
+ // In plain-memory mode the deepest level is skipped, since it is resized
+ // anyway when saved.
+ void lmtable::resize_level(int level, const char* outfilename, int mmap)
+ {
+ if (getCurrentSize(level) <= 0) return;
+ if (mmap > 0) {
+ resize_level_mmap(level, outfilename);
+ } else if (level < maxlev) {
+ resize_level_nommap(level);
+ }
+ }
+
+ // Shrinks the mmap-backed storage of one level: unmaps, truncates the backing
+ // file to the exact data size, and remaps it read/write.
+ void lmtable::resize_level_mmap(int level, const char* outfilename)
+ {
+ //getting the level-dependent filename
+ char nameNgrams[BUFSIZ];
+ sprintf(nameNgrams,"%s-%dgrams",outfilename,level);
+
+ //recompute exact filesize
+ table_pos_t filesize=(table_pos_t) cursize[level] * nodesize(tbltype[level]);
+
+ //opening output file
+ FILE *fd = fopen(nameNgrams, "r+");
+ if (fd == NULL) { //fix: fd was passed to fileno() without a failure check
+ perror(nameNgrams);
+ exit(3);
+ }
+
+ // set the file to the proper size:
+ Munmap(table[level]-tableGaps[level],(table_pos_t) filesize+tableGaps[level],0);
+ ftruncate(fileno(fd),filesize);
+ table[level]=(char *)(MMap(fileno(fd),PROT_READ|PROT_WRITE,0,filesize,&tableGaps[level]));
+ maxsize[level]=cursize[level];
+
+ //fix: the FILE* used to leak; per POSIX, closing the descriptor does not
+ //invalidate an established memory mapping
+ fclose(fd);
+ }
+
+ // Shrinks the heap storage of one level to its exact current fill by copying
+ // the live data into a tight new buffer and releasing the old one.
+ void lmtable::resize_level_nommap(int level)
+ {
+ VERBOSE(2,"lmtable::resize_level_nommap START Level " << level << "\n");
+
+ //recompute exact filesize
+ table_pos_t filesize=(table_pos_t) cursize[level] * nodesize(tbltype[level]);
+
+ char* ptr = new char[filesize];
+ memcpy(ptr,table[level],filesize);
+ //fix: old buffer was allocated with new char[...], so array delete is
+ //required; plain delete is UB
+ delete [] table[level];
+ table[level]=ptr;
+ maxsize[level]=cursize[level];
+
+ VERBOSE(2,"lmtable::resize_level_nommap END Level " << level << "\n");
+ }
+
+
+ //manages the long header of a bin file
+ //and allocates table for each n-gram level
+
+ //manages the long header of a bin file
+ //and allocates table for each n-gram level
+ //'header' is the magic word already consumed from the stream:
+ //"blmt"/"Qblmt" with optional trailing "I" marking an inverted trie.
+ void lmtable::loadbin_header(istream& inp,const char* header)
+ {
+
+ // read rest of header
+ inp >> maxlev;
+
+ //set the inverted flag to false, in order to rely on the header only
+ isInverted=false;
+
+ if (strncmp(header,"Qblmt",5)==0) {
+ isQtable=true;
+ if (strncmp(header,"QblmtI",6)==0)
+ isInverted=true;
+ } else if(strncmp(header,"blmt",4)==0) {
+ isQtable=false;
+ if (strncmp(header,"blmtI",5)==0)
+ isInverted=true;
+ } else error((char*)"loadbin: LM file is not in binary format");
+
+ configure(maxlev,isQtable);
+
+ // per-level entry counts follow the magic word
+ for (int l=1; l<=maxlev; l++) {
+ inp >> cursize[l];
+ maxsize[l]=cursize[l];
+ }
+
+ char header2[MAX_LINE];
+ if (isQtable) {
+ // quantized tables add a "NumCenters" line with one count per level
+ inp >> header2;
+ for (int i=1; i<=maxlev; i++) {
+ inp >> NumCenters[i];
+ VERBOSE(2,"reading " << NumCenters[i] << " centers" << "\n");
+ }
+ }
+ // consume the rest of the current line so binary payload starts cleanly
+ inp.getline(header2, MAX_LINE);
+ }
+
+ //load codebook of level l
+ void lmtable::loadbin_codebook(istream& inp,int l)
+ {
+ Pcenters[l]=new float [NumCenters[l]];
+ inp.read((char*)Pcenters[l],NumCenters[l] * sizeof(float));
+ if (l<maxlev) {
+ Bcenters[l]=new float [NumCenters[l]];
+ inp.read((char *)Bcenters[l],NumCenters[l]*sizeof(float));
+ }
+ }
+
+
+ //load a binary lmfile
+
+ //load a binary lmfile: header, dictionary, then each level (possibly
+ //memory-mapped from 'filename' when mmap>0)
+ void lmtable::loadbin(istream& inp, const char* header, const char* filename,int mmap)
+ {
+ VERBOSE(2,"loadbin()" << "\n");
+ loadbin_header(inp,header);
+ loadbin_dict(inp);
+
+ // truncate the model to the level cap requested by the caller
+ VERBOSE(3,"lmtable::maxlev" << maxlev << std::endl);
+ if (maxlev>requiredMaxlev) maxlev=requiredMaxlev;
+ VERBOSE(3,"lmtable::maxlev:" << maxlev << std::endl);
+ VERBOSE(3,"lmtable::requiredMaxlev" << requiredMaxlev << std::endl);
+
+ //if MMAP is used, then open the file
+ if (filename && mmap>0) {
+
+#ifdef WIN32
+ error("lmtable::loadbin mmap facility not yet supported under WIN32\n");
+#else
+
+ // levels >= memmap are mapped from disk instead of read into memory
+ if (mmap <= maxlev) memmap=mmap;
+ else error((char*)"keep_on_disk value is out of range\n");
+
+ if ((diskid=open(filename, O_RDONLY))<0) {
+ VERBOSE(2,"cannot open " << filename << std::endl);
+ error((char*)"dying");
+ }
+
+ //check that the LM is uncompressed
+ //NOTE(review): the return value of read() is not checked here
+ char miniheader[4];
+ read(diskid,miniheader,4);
+ if (strncmp(miniheader,"Qblm",4) && strncmp(miniheader,"blmt",4))
+ error((char*)"mmap functionality does not work with compressed binary LMs\n");
+#endif
+ }
+
+ for (int l=1; l<=maxlev; l++) {
+ loadbin_level(inp,l);
+ }
+ VERBOSE(2,"done" << std::endl);
+ }
+
+
+ //load only the dictionary of a binary lmfile
+ void lmtable::loadbin_dict(istream& inp)
+ {
+ VERBOSE(2,"lmtable::loadbin_dict()" << std::endl);
+ lmtable::getDict()->load(inp);
+ VERBOSE(2,"dict->size(): " << lmtable::getDict()->size() << std::endl);
+ }
+
+ //load ONE level of a binary lmfile
+ void lmtable::loadbin_level(istream& inp, int level)
+ {
+ VERBOSE(2,"loadbin_level (level " << level << std::endl);
+
+ if (isQtable)
+ {
+ loadbin_codebook(inp,level);
+ }
+ if ((memmap == 0) || (level < memmap))
+ {
+ VERBOSE(2,"loading " << cursize[level] << " " << level << "-grams" << std::endl);
+ table[level]=new char[(table_pos_t) cursize[level] * nodesize(tbltype[level])];
+ inp.read(table[level],(table_pos_t) cursize[level] * nodesize(tbltype[level]));
+ } else {
+
+#ifdef WIN32
+ error((char*)"mmap not available under WIN32\n");
+#else
+ VERBOSE(2,"mapping " << cursize[level] << " " << level << "-grams" << std::endl);
+ tableOffs[level]=inp.tellg();
+ table[level]=(char *)MMap(diskid,PROT_READ,
+ tableOffs[level], (table_pos_t) cursize[level]*nodesize(tbltype[level]),
+ &tableGaps[level]);
+ table[level]+=(table_pos_t) tableGaps[level];
+ VERBOSE(2,"tableOffs " << tableOffs[level] << " tableGaps" << tableGaps[level] << "-grams" << std::endl);
+ inp.seekg((table_pos_t) cursize[level]*nodesize(tbltype[level]),ios_base::cur);
+#endif
+ }
+ VERBOSE(2,"done (level " << level << std::endl);
+ }
+
+ // Looks up the last 'lev' words of the n-gram 'ng' (of length n) by walking
+ // the trie one level at a time. Returns 1 and fills ng.prob/bow/link/lev/...
+ // on success; returns 0 if the path is absent or the entry is pruned.
+ int lmtable::get(ngram& ng,int n,int lev)
+ {
+ totget[lev]++;
+
+ if (lev > maxlev) error((char*)"get: lev exceeds maxlevel");
+ if (n < lev) error((char*)"get: ngram is too small");
+
+ //set boundaries for 1-gram
+ table_entry_pos_t offset=0,limit=cursize[1];
+
+ //information of table entries
+ char* found;
+ LMT_TYPE ndt;
+ ng.link=NULL;
+ ng.lev=0;
+
+ for (int l=1; l<=lev; l++) {
+
+ //initialize entry information
+ found = NULL;
+ ndt=tbltype[l];
+
+#ifdef LMT_CACHE_ENABLE
+ // consult the per-level lookup cache before searching the table
+ bool hit = false;
+ if (lmtcache[l] && lmtcache[l]->get(ng.wordp(n),found)) {
+ hit=true;
+ } else {
+ search(l,
+ offset,
+ (limit-offset),
+ nodesize(ndt),
+ ng.wordp(n-l+1),
+ LMT_FIND,
+ &found);
+ }
+
+
+
+ //insert both found and not found items!!!
+// if (lmtcache[l] && hit==true) {
+
+ //insert only not found items!!!
+ if (lmtcache[l] && hit==false) {
+ const char* found2=found;
+ lmtcache[l]->add(ng.wordp(n),found2);
+ }
+#else
+ search(l,
+ offset,
+ (limit-offset),
+ nodesize(ndt),
+ ng.wordp(n-l+1),
+ LMT_FIND,
+ &found);
+#endif
+
+ if (!found) return 0;
+
+ float pr = prob(found,ndt);
+ if (pr==NOPROB) return 0; //pruned n-gram
+
+ ng.path[l]=found; //store path of found entries
+ //deepest level has no back-off weight
+ ng.bow=(l<maxlev?bow(found,ndt):0);
+ ng.prob=pr;
+ ng.link=found;
+ ng.info=ndt;
+ ng.lev=l;
+
+ if (l<maxlev) { //set start/end point for next search
+
+ //if current offset is at the bottom also that of successors will be
+ if (offset+1==cursize[l]) limit=cursize[l+1];
+ else limit=bound(found,ndt);
+
+ //if current start is at the begin, then also that of successors will be
+ if (found==table[l]) offset=0;
+ else offset=bound((found - nodesize(ndt)),ndt);
+
+ MY_ASSERT(offset!=BOUND_EMPTY1);
+ MY_ASSERT(limit!=BOUND_EMPTY1);
+ }
+ }
+
+
+ //put information inside ng
+ ng.size=n;
+ ng.freq=0;
+ //number of successors of the found entry (0 at the deepest level)
+ ng.succ=(lev<maxlev?limit-offset:0);
+
+#ifdef TRACE_CACHELM
+ if (ng.size==maxlev && sentence_id>0) {
+ *cacheout << sentence_id << " miss " << ng << " " << ng.link << "\n";
+ }
+#endif
+ return 1;
+ }
+
+
+ //recursively prints the language model table
+
+ void lmtable::dumplm(fstream& out,ngram ng, int ilev, int elev, table_entry_pos_t ipos,table_entry_pos_t epos)
+ {
+
+ LMT_TYPE ndt=tbltype[ilev];
+ ngram ing(ng.dict);
+ int ndsz=nodesize(ndt);
+
+ MY_ASSERT(ng.size==ilev-1);
+
+ //Note that ipos and epos are always larger than or equal to 0 because they are unsigned int
+ MY_ASSERT(epos<=cursize[ilev]);
+ MY_ASSERT(ipos<epos);
+ ng.pushc(0);
+
+ for (table_entry_pos_t i=ipos; i<epos; i++) {
+ char* found=table[ilev]+ (table_pos_t) i * ndsz;
+ *ng.wordp(1)=word(found);
+
+ float ipr=prob(found,ndt);
+
+ //skip pruned n-grams
+ if(isPruned && ipr==NOPROB) continue;
+
+ if (ilev<elev) {
+ //get first and last successor position
+ table_entry_pos_t isucc=(i>0?bound(table[ilev]+ (table_pos_t) (i-1) * ndsz,ndt):0);
+ table_entry_pos_t esucc=bound(found,ndt);
+
+ if (isucc < esucc) //there are successors!
+ dumplm(out,ng,ilev+1,elev,isucc,esucc);
+ } else {
+ out << ipr <<"\t";
+
+ // if table is inverted then revert n-gram
+ if (isInverted && (ng.size>1)) {
+ ing.invert(ng);
+ for (int k=ing.size; k>=1; k--) {
+ if (k<ing.size) out << " ";
+ out << lmtable::getDict()->decode(*ing.wordp(k));
+ }
+ } else {
+ for (int k=ng.size; k>=1; k--) {
+ if (k<ng.size) out << " ";
+ out << lmtable::getDict()->decode(*ng.wordp(k));
+ }
+ }
+
+ if (ilev<maxlev) {
+ float ibo=bow(table[ilev]+ (table_pos_t)i * ndsz,ndt);
+ if (isQtable){
+ out << "\t" << ibo;
+ }
+ else{
+ if ((ibo>UPPER_SINGLE_PRECISION_OF_0 || ibo<-UPPER_SINGLE_PRECISION_OF_0)) out << "\t" << ibo;
+ }
+ }
+ out << "\n";
+ }
+ }
+ }
+
+ //succscan iteratively returns all successors of an ngram h for which
+ //get(h,h.size,h.size) returned true.
+
+ //succscan iteratively returns all successors of an ngram h for which
+ //get(h,h.size,h.size) returned true.
+ //LMT_INIT primes the scan (computes h.succ and h.succlink); each LMT_CONT
+ //call yields the next successor word into ng, returning 0 when exhausted.
+ int lmtable::succscan(ngram& h,ngram& ng,LMT_ACTION action,int lev)
+ {
+ MY_ASSERT(lev==h.lev+1 && h.size==lev && lev<=maxlev);
+
+ LMT_TYPE ndt=tbltype[h.lev];
+ int ndsz=nodesize(ndt);
+
+ table_entry_pos_t offset;
+ switch (action) {
+
+ case LMT_INIT:
+ //reset ngram local indexes
+
+ ng.size=lev;
+ ng.trans(h);
+ //get number of successors of h
+ ng.midx[lev]=0;
+ //successors of h span [bound(prev entry), bound(h)) in table[lev]
+ offset=(h.link>table[h.lev]?bound(h.link-ndsz,ndt):0);
+ h.succ=bound(h.link,ndt)-offset;
+ h.succlink=table[lev]+(table_pos_t) offset * nodesize(tbltype[lev]);
+ return 1;
+
+ case LMT_CONT:
+ if (ng.midx[lev] < h.succ) {
+ //put current word into ng
+ *ng.wordp(1)=word(h.succlink+(table_pos_t) ng.midx[lev]*nodesize(tbltype[lev]));
+ ng.midx[lev]++;
+ return 1;
+ } else
+ return 0;
+
+ default:
+ exit_error(IRSTLM_ERROR_MODEL, "succscan: only permitted options are LMT_INIT and LMT_CONT");
+ }
+ return 0;
+ }
+
+ //maxsuffptr returns the largest suffix of an n-gram that is contained
+ //in the LM table. This can be used as a compact representation of the
+ //(n-1)-gram state of a n-gram LM. if the input k-gram has k>=n then it
+ //is trimmed to its n-1 suffix.
+
+ //non recursive version
+ //maxsuffptr returns the largest suffix of an n-gram that is contained
+ //in the LM table. This can be used as a compact representation of the
+ //(n-1)-gram state of a n-gram LM. if the input k-gram has k>=n then it
+ //is trimmed to its n-1 suffix.
+
+ //non recursive version
+ //on return, *size (if non-NULL) holds the length of the matched suffix
+ const char *lmtable::maxsuffptr(ngram ong, unsigned int* size)
+ {
+ VERBOSE(3,"const char *lmtable::maxsuffptr(ngram ong, unsigned int* size)\n");
+
+ if (ong.size==0) {
+ if (size!=NULL) *size=0;
+ return (char*) NULL;
+ }
+
+
+ if (isInverted) {
+ if (ong.size>maxlev) ong.size=maxlev; //if larger than maxlen reduce size
+ ngram ing=ong; //inverted ngram
+
+ ing.invert(ong);
+
+ get(ing,ing.size,ing.size); // dig in the trie
+ if (ing.lev > 0) { //found something?
+ unsigned int isize = MIN(ing.lev,(ing.size-1)); //find largest n-1 gram suffix
+ if (size!=NULL) *size=isize;
+ return ing.path[isize];
+ } else { // means a real unknown word!
+ if (size!=NULL) *size=0; //default statesize for zero-gram!
+ return NULL; //default stateptr for zero-gram!
+ }
+ } else {
+ if (ong.size>0) ong.size--; //always reduced by 1 word
+
+ if (ong.size>=maxlev) ong.size=maxlev-1; //if still larger or equals to maxlen reduce again
+
+ if (size!=NULL) *size=ong.size; //will return the largest found ong.size
+ //try progressively shorter suffixes until one is found in the table
+ for (ngram ng=ong; ng.size>0; ng.size--) {
+ if (get(ng,ng.size,ng.size)) {
+ // if (ng.succ==0) (*size)--;
+ // if (size!=NULL) *size=ng.size;
+ if (size!=NULL)
+ {
+ //a match without successors cannot be extended: shrink the state
+ if (ng.succ==0) *size=ng.size-1;
+ else *size=ng.size;
+ }
+ return ng.link;
+ }
+ }
+ if (size!=NULL) *size=0;
+ return NULL;
+ }
+ }
+
+
+ // Cached wrapper around maxsuffptr(): consults/updates the per-size
+ // prob_and_state cache when PS_CACHE_ENABLE is defined.
+ const char *lmtable::cmaxsuffptr(ngram ong, unsigned int* size)
+ {
+ //NOTE(review): the trace message names maxsuffptr although this is cmaxsuffptr
+ VERBOSE(3,"const char *lmtable::maxsuffptr(ngram ong, unsigned int* size) ong:|" << ong << "|\n");
+
+ if (ong.size==0) {
+ if (size!=NULL) *size=0;
+ return (char*) NULL;
+ }
+
+ if (size!=NULL) *size=ong.size; //will return the largest found ong.size
+
+#ifdef PS_CACHE_ENABLE
+ prob_and_state_t pst;
+
+ size_t orisize=ong.size;
+ if (ong.size>=maxlev) ong.size=maxlev;
+
+ //cache hit
+ // if (prob_and_state_cache && ong.size==maxlev && prob_and_state_cache->get(ong.wordp(maxlev),pst)) {
+ if (prob_and_state_cache[ong.size] && prob_and_state_cache[ong.size]->get(ong.wordp(ong.size),pst)) {
+ *size=pst.statesize;
+ return pst.state;
+ }
+ ong.size = orisize;
+
+ //cache miss
+ unsigned int isize; //internal state size variable
+ char* found=(char *)maxsuffptr(ong,&isize);
+
+ //cache insert
+ //IMPORTANT: this function updates only two fields (state, statesize) of the entry of the cache; the reminaing fields (logpr, bow, bol, extendible) are undefined; hence, it should not be used before the corresponding clprob()
+
+ if (ong.size>=maxlev) ong.size=maxlev;
+ // if (prob_and_state_cache && ong.size==maxlev) {
+ if (prob_and_state_cache[ong.size]) {
+ pst.state=found;
+ pst.statesize=isize;
+ // prob_and_state_cache->add(ong.wordp(maxlev),pst);
+ prob_and_state_cache[ong.size]->add(ong.wordp(ong.size),pst);
+ }
+ if (size!=NULL) *size=isize;
+ return found;
+#else
+ //no cache compiled in: delegate directly
+ return (char *)maxsuffptr(ong,size);
+#endif
+ }
+
+
+ //this function simulates the cmaxsuffptr(ngram, ...) but it takes as input an array of codes instead of the ngram
+ const char *lmtable::cmaxsuffptr(int* codes, int sz, unsigned int* size)
+ {
+ VERBOSE(3,"const char *lmtable::cmaxsuffptr(int* codes, int sz, unsigned int* size)\n");
+
+ if (sz==0) {
+ if (size!=NULL) *size=0;
+ return (char*) NULL;
+ }
+
+ if (sz>maxlev) sz=maxlev; //adjust n-gram level to table size
+
+#ifdef PS_CACHE_ENABLE
+ //cache hit
+ prob_and_state_t pst;
+
+ //cache hit
+ // if (prob_and_state_cache && sz==maxlev && prob_and_state_cache->get(codes,pst)) {
+ if (prob_and_state_cache[sz] && prob_and_state_cache[sz]->get(codes,pst)) {
+ if (size) *size = pst.statesize;
+ return pst.state;
+ }
+
+ //create the actual ngram
+ ngram ong(dict);
+ ong.pushc(codes,sz);
+ MY_ASSERT (ong.size == sz);
+
+ //cache miss
+ unsigned int isize; //internal state size variable
+ char* found=(char *)maxsuffptr(ong,&isize);
+
+ //cache insert
+ //IMPORTANT: this function updates only two fields (state, statesize) of the entry of the cache; the reminaing fields (logpr, bow, bol, extendible) are undefined; hence, it should not be used before the corresponding clprob()
+ if (ong.size>=maxlev) ong.size=maxlev;
+ // if (prob_and_state_cache && ong.size==maxlev) {
+ if (prob_and_state_cache[sz]) {
+ pst.state=found;
+ pst.statesize=isize;
+ // prob_and_state_cache->add(ong.wordp(maxlev),pst);
+ prob_and_state_cache[sz]->add(ong.wordp(ong.size),pst);
+ }
+ if (size!=NULL) *size=isize;
+ return found;
+#else
+ //create the actual ngram
+ ngram ong(dict);
+ ong.pushc(codes,sz);
+ MY_ASSERT (ong.size == sz);
+ /*
+ unsigned int isize; //internal state size variable
+ char* found=(char *) maxsuffptr(ong,&isize);
+ char* found2=(char *) maxsuffptr(ong,size);
+ if (size!=NULL) *size=isize;
+ return found;
+ */
+ return maxsuffptr(ong,size);
+#endif
+ }
+
+
+
+ //returns log10prob of n-gram
+ //bow: backoff weight
+ //bol: backoff level
+
+ //additional infos related to use in Moses:
+ //maxsuffptr: recombination state after the LM call
+ //statesize: lenght of the recombination state
+ //extensible: true if the deepest found ngram has successors
+ //lastbow: bow of the deepest found ngram
+
+ //non recursive version, also includes maxsuffptr
+ //returns log10prob of n-gram, accumulating back-off weights when the full
+ //n-gram is absent.
+ //bow: accumulated backoff weight; bol: backoff level (how many words were
+ //backed off); maxsuffptr/statesize: recombination state and its length;
+ //extendible: true if the deepest found ngram has successors (inverted only);
+ //lastbow: bow of the deepest found ngram (inverted only).
+ //non recursive version, also includes maxsuffptr
+ double lmtable::lprob(ngram ong,double* bow, int* bol, char** maxsuffptr,unsigned int* statesize,
+ bool* extendible, double *lastbow)
+ {
+ VERBOSE(3," lmtable::lprob(ngram) ong " << ong << "\n");
+
+ if (ong.size==0) return 0.0; //sanity check
+ if (ong.size>maxlev) ong.size=maxlev; //adjust n-gram level to table size
+
+ if (bow) *bow=0; //initialize back-off weight
+ if (bol) *bol=0; //initialize back-off level
+
+
+ double rbow=0,lpr=0; //output back-off weight and logprob
+ float ibow,iprob; //internal back-off weight and logprob
+
+
+ if (isInverted) {
+ ngram ing=ong; //Inverted ngram TRIE
+
+ ing.invert(ong);
+ get(ing,ing.size,ing.size); // dig in the trie
+ if (ing.lev >0) { //found something?
+ //quantized tables store center indices; dequantize through Pcenters
+ iprob=ing.prob;
+ lpr = (double)(isQtable?Pcenters[ing.lev][(qfloat_t)iprob]:iprob);
+ if (*ong.wordp(1)==dict->oovcode()) lpr-=logOOVpenalty; //add OOV penalty
+ if (statesize) *statesize=MIN(ing.lev,(ing.size-1)); //find largest n-1 gram suffix
+ if (maxsuffptr) *maxsuffptr=ing.path[MIN(ing.lev,(ing.size-1))];
+ if (extendible) *extendible=succrange(ing.path[ing.lev],ing.lev)>0;
+ if (lastbow) *lastbow=(double) (isQtable?Bcenters[ing.lev][(qfloat_t)ing.bow]:ing.bow);
+ } else { // means a real unknown word!
+ lpr=-log(UNIGRAM_RESOLUTION)/M_LN10;
+ if (statesize) *statesize=0; //default statesize for zero-gram!
+ if (maxsuffptr) *maxsuffptr=NULL; //default stateptr for zero-gram!
+ }
+
+ if (ing.lev < ing.size) { //compute backoff weight
+ int depth=(ing.lev>0?ing.lev:1); //ing.lev=0 (real unknown word) is still a 1-gram
+ if (bol) *bol=ing.size-depth;
+ ing.size--; //get n-gram context
+ get(ing,ing.size,ing.size); // dig in the trie
+ if (ing.lev>0) { //found something?
+ //collect back-off weights
+ for (int l=depth; l<=ing.lev; l++) {
+ //start from first back-off level
+ MY_ASSERT(ing.path[l]!=NULL); //check consistency of table
+ ibow=this->bow(ing.path[l],tbltype[l]);
+ rbow+= (double) (isQtable?Bcenters[l][(qfloat_t)ibow]:ibow);
+ //avoids bad quantization of bow of <unk>
+ // if (isQtable && (*ing.wordp(1)==dict->oovcode())) {
+ if (isQtable && (*ing.wordp(ing.size)==dict->oovcode())) {
+ rbow-=(double)Bcenters[l][(qfloat_t)ibow];
+ }
+ }
+ }
+ }
+
+ if (bow) (*bow)=rbow;
+ return rbow + lpr;
+ } //Direct ngram TRIE
+ else {
+ MY_ASSERT((extendible == NULL) || (extendible && *extendible==false));
+ // MY_ASSERT(lastbow==NULL);
+ //back off to progressively shorter suffixes, accumulating bows in rbow
+ for (ngram ng=ong; ng.size>0; ng.size--) {
+ if (get(ng,ng.size,ng.size)) {
+ iprob=ng.prob;
+ lpr = (double)(isQtable?Pcenters[ng.size][(qfloat_t)iprob]:iprob);
+ if (*ng.wordp(1)==dict->oovcode()) lpr-=logOOVpenalty; //add OOV penalty
+ if (maxsuffptr || statesize) { //one extra step is needed if ng.size=ong.size
+ if (ong.size==ng.size) {
+ ng.size--;
+ get(ng,ng.size,ng.size);
+ }
+ if (statesize) *statesize=ng.size;
+ if (maxsuffptr) *maxsuffptr=ng.link; //we should check ng.link != NULL
+ }
+ return rbow+lpr;
+ } else {
+ if (ng.size==1) { //means a real unknown word!
+ if (maxsuffptr) *maxsuffptr=NULL; //default stateptr for zero-gram!
+ if (statesize) *statesize=0;
+ return rbow -log(UNIGRAM_RESOLUTION)/M_LN10;
+ } else { //compute backoff
+ if (bol) (*bol)++; //increase backoff level
+ if (ng.lev==(ng.size-1)) { //if search stopped at previous level
+ ibow=ng.bow;
+ rbow+= (double) (isQtable?Bcenters[ng.lev][(qfloat_t)ibow]:ibow);
+ //avoids bad quantization of bow of <unk>
+ if (isQtable && (*ng.wordp(2)==dict->oovcode())) {
+ rbow-=(double)Bcenters[ng.lev][(qfloat_t)ibow];
+ }
+ }
+ if (bow) (*bow)=rbow;
+ }
+
+ }
+
+ }
+ }
+ MY_ASSERT(0); //never pass here!!!
+ return 1.0;
+ }
+
+
+ //return log10 probs; use cache memory
+ double lmtable::clprob(ngram ong,double* bow, int* bol, char** state,unsigned int* statesize,bool* extendible)
+ {
+ //cached front-end to lprob: results are memoized per n-gram size in prob_and_state_cache
+ VERBOSE(3,"double lmtable::clprob(ngram ong,double* bow, int* bol, char** state,unsigned int* statesize,bool* extendible) ong:|" << ong << "|\n");
+
+#ifdef TRACE_CACHELM
+ // if (probcache && ong.size==maxlev && sentence_id>0) {
+ if (probcache && sentence_id>0) {
+ *cacheout << sentence_id << " " << ong << "\n";
+ }
+#endif
+
+ if (ong.size==0) { //empty n-gram: neutral result
+ if (statesize!=NULL) *statesize=0;
+ if (state!=NULL) *state=NULL;
+ if (extendible!=NULL) *extendible=false;
+ return 0.0;
+ }
+
+ if (ong.size>maxlev) ong.size=maxlev; //adjust n-gram level to table size
+
+#ifdef PS_CACHE_ENABLE
+ double logpr = 0.0;
+ //cache hit: return the memoized result
+ prob_and_state_t pst_get;
+
+ // if (prob_and_state_cache && ong.size==maxlev && prob_and_state_cache->get(ong.wordp(maxlev),pst_get)) {
+ if (prob_and_state_cache[ong.size] && prob_and_state_cache[ong.size]->get(ong.wordp(ong.size),pst_get)) {
+ logpr=pst_get.logpr;
+ if (bow) *bow = pst_get.bow;
+ if (bol) *bol = pst_get.bol;
+ if (state) *state = pst_get.state;
+ if (statesize) *statesize = pst_get.statesize;
+ if (extendible) *extendible = pst_get.extendible;
+
+ return logpr;
+ }
+
+ //cache miss: compute via lprob, then memoize
+
+ prob_and_state_t pst_add;
+ logpr = pst_add.logpr = lmtable::lprob(ong, &(pst_add.bow), &(pst_add.bol), &(pst_add.state), &(pst_add.statesize), &(pst_add.extendible));
+
+
+ if (bow) *bow = pst_add.bow;
+ if (bol) *bol = pst_add.bol;
+ if (state) *state = pst_add.state;
+ if (statesize) *statesize = pst_add.statesize;
+ if (extendible) *extendible = pst_add.extendible;
+
+
+ // if (prob_and_state_cache && ong.size==maxlev) {
+ // prob_and_state_cache->add(ong.wordp(maxlev),pst_add);
+ // }
+ if (prob_and_state_cache[ong.size]) {
+ prob_and_state_cache[ong.size]->add(ong.wordp(ong.size),pst_add);
+ }
+ return logpr;
+#else
+ return lmtable::lprob(ong, bow, bol, state, statesize, extendible);
+#endif
+ };
+
+
+ //return log10 probs; use cache memory
+ //this function simulates clprob(ngram, ...) but takes as input an array of word codes instead of an ngram
+ double lmtable::clprob(int* codes, int sz, double* bow, int* bol, char** state,unsigned int* statesize,bool* extendible)
+ {
+ //cached lprob over a raw code array; builds the ngram only on a cache miss
+ VERBOSE(3," double lmtable::clprob(int* codes, int sz, double* bow, int* bol, char** state,unsigned int* statesize,bool* extendible)\n");
+#ifdef TRACE_CACHELM
+ // if (probcache && sz==maxlev && sentence_id>0) {
+ if (probcache && sentence_id>0) {
+ *cacheout << sentence_id << "\n";
+ //print the codes of the vector ng
+ }
+#endif
+
+ if (sz==0) { //empty input: neutral result
+ if (statesize!=NULL) *statesize=0;
+ if (state!=NULL) *state=NULL;
+ if (extendible!=NULL) *extendible=false;
+ return 0.0;
+ }
+
+ if (sz>maxlev) sz=maxlev; //adjust n-gram level to table size
+
+#ifdef PS_CACHE_ENABLE
+ double logpr;
+
+ //cache hit: return the memoized result
+ prob_and_state_t pst_get;
+
+ // if (prob_and_state_cache && sz==maxlev && prob_and_state_cache->get(codes,pst_get)) {
+ if (prob_and_state_cache[sz] && prob_and_state_cache[sz]->get(codes,pst_get)) {
+
+ logpr=pst_get.logpr;
+ if (bow) *bow = pst_get.bow;
+ if (bol) *bol = pst_get.bol;
+ if (state) *state = pst_get.state;
+ if (statesize) *statesize = pst_get.statesize;
+ if (extendible) *extendible = pst_get.extendible;
+
+ return logpr;
+ }
+
+
+ //create the actual ngram
+ ngram ong(dict);
+ ong.pushc(codes,sz);
+ MY_ASSERT (ong.size == sz);
+
+ //cache miss: compute via lprob, then memoize
+ prob_and_state_t pst_add;
+ logpr = pst_add.logpr = lmtable::lprob(ong, &(pst_add.bow), &(pst_add.bol), &(pst_add.state), &(pst_add.statesize), &(pst_add.extendible));
+
+
+ if (bow) *bow = pst_add.bow;
+ if (bol) *bol = pst_add.bol;
+ if (state) *state = pst_add.state;
+ if (statesize) *statesize = pst_add.statesize;
+ if (extendible) *extendible = pst_add.extendible;
+
+
+ // if (prob_and_state_cache && ong.size==maxlev) {
+ // prob_and_state_cache->add(ong.wordp(maxlev),pst_add);
+ // }
+ if (prob_and_state_cache[sz]) {
+ //NOTE(review): lookup above is keyed on 'codes' while insertion is keyed on ong.wordp(ong.size) — assumed identical after pushc; confirm
+ prob_and_state_cache[sz]->add(ong.wordp(ong.size),pst_add);
+ }
+ return logpr;
+#else
+
+ //create the actual ngram
+ ngram ong(dict);
+ ong.pushc(codes,sz);
+ MY_ASSERT (ong.size == sz);
+
+ /*
+ logpr = lmtable::lprob(ong, bow, bol, state, statesize, extendible);
+ return logpr;
+ */
+ return lmtable::lprob(ong, bow, bol, state, statesize, extendible);
+#endif
+ };
+
+
+ int lmtable::succrange(node ndp,int level,table_entry_pos_t* isucc,table_entry_pos_t* esucc)
+ {
+ //returns the number of successors of node ndp at level+1;
+ //optionally returns the successor index range [first,last) through isucc/esucc
+ table_entry_pos_t first,last;
+ LMT_TYPE ndt=tbltype[level];
+
+ //get table boundaries for next level
+ if (level<maxlev) {
+ first = ndp>table[level]? bound(ndp-nodesize(ndt), ndt) : 0; //previous node's bound, or table start
+ last = bound(ndp, ndt);
+ } else {
+ first=last=0; //top level: no successors by construction
+ }
+ if (isucc) *isucc=first;
+ if (esucc) *esucc=last;
+
+ return last-first;
+ }
+
+
+ void lmtable::stat(int level)
+ {
+ //prints memory/usage statistics of the table to stdout;
+ //level>1 additionally prints dictionary statistics
+ table_pos_t totmem=0,memory;
+ float mega=1024 * 1024; //bytes per megabyte
+
+ cout.precision(2);
+
+ cout << "lmtable class statistics\n";
+
+ cout << "levels " << maxlev << "\n";
+ for (int l=1; l<=maxlev; l++) {
+ memory=(table_pos_t) cursize[l] * nodesize(tbltype[l]);
+ cout << "lev " << l
+ << " entries "<< cursize[l]
+ << " used mem " << memory/mega << "Mb\n";
+ totmem+=memory;
+ }
+
+ cout << "total allocated mem " << totmem/mega << "Mb\n";
+
+ cout << "total number of get and binary search calls\n";
+ for (int l=1; l<=maxlev; l++) {
+ cout << "level " << l << " get: " << totget[l] << " bsearch: " << totbsearch[l] << "\n";
+ }
+
+ if (level >1 ) lmtable::getDict()->stat();
+
+ stat_caches();
+
+ }
+
+ void lmtable::reset_mmap()
+ {
+ //re-establishes the read-only memory mapping of all mmapped levels (no-op on WIN32)
+#ifndef WIN32
+ if (memmap>0 and memmap<=maxlev)
+ for (int l=memmap; l<=maxlev; l++) {
+ VERBOSE(2,"resetting mmap at level:" << l << std::endl);
+ Munmap(table[l]-tableGaps[l],(table_pos_t) cursize[l]*nodesize(tbltype[l])+tableGaps[l],0);
+ table[l]=(char *)MMap(diskid,PROT_READ,
+ tableOffs[l], (table_pos_t)cursize[l]*nodesize(tbltype[l]),
+ &tableGaps[l]);
+ table[l]+=(table_pos_t)tableGaps[l]; //skip the page-alignment gap
+ }
+#endif
+ }
+
+ // ng: input n-gram
+
+ // *lkp: prob of the n-(*bol) gram
+ // *bop: backoff weight vector
+ // *bol: backoff level
+
+ double lmtable::lprobx(ngram ong,
+ double *lkp,
+ double *bop,
+ int *bol)
+ {
+ //log10 probability of ong with explicit back-off bookkeeping (see comments above)
+ double bo, lbo, pr;
+ float ipr;
+ //int ipr;
+ ngram ng(dict), ctx(dict);
+
+ if(bol) *bol=0;
+ if(ong.size==0) {
+ if(lkp) *lkp=0;
+ return 0; // lprob returns 0 here; earlier versions of lprobx used LOGZERO
+ }
+ if(ong.size>maxlev) ong.size=maxlev;
+ ctx = ng = ong;
+ bo=0;
+ ctx.shift(); //ctx is the (n-1)-gram context of ng
+ while(!get(ng)) { // back-off
+
+ //OOV not included in dictionary
+ if(ng.size==1) {
+ pr = -log(UNIGRAM_RESOLUTION)/M_LN10;
+ if(lkp) *lkp=pr; // this is the innermost probability
+ pr += bo; //add all the accumulated back-off probability
+ return pr;
+ }
+ // backoff-probability
+ lbo = 0.0; //local back-off: default is logprob 0
+ if(get(ctx)) { //this can be replaced with (ng.lev==(ng.size-1))
+ ipr = ctx.bow;
+ lbo = isQtable?Bcenters[ng.size][(qfloat_t)ipr]:ipr;
+ //lbo = isQtable?Bcenters[ng.size][ipr]:*(float*)&ipr;
+ }
+ if(bop) *bop++=lbo; //record one back-off weight per level backed off
+ if(bol) ++*bol;
+ bo += lbo;
+ ng.size--;
+ ctx.size--;
+ }
+ ipr = ng.prob;
+ pr = isQtable?Pcenters[ng.size][(qfloat_t)ipr]:ipr;
+ //pr = isQtable?Pcenters[ng.size][ipr]:*((float*)&ipr);
+ if(lkp) *lkp=pr;
+ pr += bo;
+ return pr;
+ }
+
+
+ // FABIO
+ table_entry_pos_t lmtable::wdprune(float *thr, int aflag)
+ {
+ //this function implements a method similar to the "Weighted Difference Method"
+ //described in "Scalable Backoff Language Models" by Kristie Seymore and Ronald Rosenfeld
+ //thr: one pruning threshold per level; aflag: use absolute value of the score
+ int l;
+ ngram ng(lmtable::getDict(),0);
+
+ isPruned=true; //the table now might contain pruned n-grams
+
+ ng.size=0;
+
+ for(l=2; l<=maxlev; l++) wdprune(thr, aflag, ng, 1, l, 0, cursize[1]); //unigrams are never pruned
+ return 0;
+ }
+
+ // FABIO: LM pruning method
+
+ table_entry_pos_t lmtable::wdprune(float *thr, int aflag, ngram ng, int ilev, int elev, table_entry_pos_t ipos, table_entry_pos_t epos, double tlk, double bo, double *ts, double *tbs)
+ {
+ //recursive worker: walks levels ilev..elev over [ipos,epos), prunes elev-grams
+ //whose weighted-difference score is below thr[elev-1]; returns number pruned
+ LMT_TYPE ndt=tbltype[ilev];
+ int ndsz=nodesize(ndt);
+ char *ndp;
+ float lk;
+ float ipr, ibo;
+ //int ipr, ibo;
+ table_entry_pos_t i, k, nk;
+
+ MY_ASSERT(ng.size==ilev-1);
+ //Note that ipos and epos are always larger than or equal to 0 because they are unsigned int
+ MY_ASSERT(epos<=cursize[ilev] && ipos<epos);
+
+ ng.pushc(0); //increase size of n-gram
+
+ for(i=ipos, nk=0; i<epos; i++) {
+
+ //scan table at next level ilev from position ipos
+ ndp = table[ilev]+(table_pos_t)i*ndsz;
+ *ng.wordp(1) = word(ndp);
+
+ //get probability
+ ipr = prob(ndp, ndt);
+ if(ipr==NOPROB) continue; // Has it been already pruned ??
+
+ if ((ilev == 1) && (*ng.wordp(ng.size) == getDict()->getcode(BOS_))) {
+ //the n-gram starts with the sentence start symbol
+ //do not consider its actual probability because it is not reliable (its frequency is manually set)
+ ipr = 0.0;
+ }
+ lk = ipr;
+
+ if(ilev<elev) { //there is a higher order
+
+ //get backoff-weight for next level
+ ibo = bow(ndp, ndt);
+ bo = ibo;
+
+ //get table boundaries for next level
+ table_entry_pos_t isucc,esucc;
+ succrange(ndp,ilev,&isucc,&esucc);
+
+ //table_entry_pos_t isucc = i>0 ? bound(ndp-ndsz, ndt) : 0;
+ //table_entry_pos_t esucc = bound(ndp, ndt);
+ if(isucc>=esucc) continue; // no successors
+
+ //look for n-grams to be pruned with this context (see
+ //back-off weight)
+ prune:
+ double nextlevel_ts=0, nextlevel_tbs=0;
+ k = wdprune(thr, aflag, ng, ilev+1, elev, isucc, esucc, tlk+lk, bo, &nextlevel_ts, &nextlevel_tbs);
+ //k is the number of pruned n-grams with this context
+ if(ilev!=elev-1) continue;
+ if(nextlevel_ts>=1 || nextlevel_tbs>=1) {
+ VERBOSE(2, "ng: " << ng <<" nextlevel_ts=" << nextlevel_ts <<" nextlevel_tbs=" << nextlevel_tbs <<" k=" << k <<" ns=" << esucc-isucc << "\n");
+ if(nextlevel_ts>=1) {
+ pscale(ilev+1, isucc, esucc, 0.999999/nextlevel_ts);
+ goto prune; //re-prune this context after rescaling its successor probs
+ }
+ }
+ // adjusts backoff:
+ // 1-sum_succ(pr(w|ng)) / 1-sum_succ(pr(w|bng))
+ bo = log((1-nextlevel_ts)/(1-nextlevel_tbs))/M_LN10;
+ ibo=(float)bo;
+ bow(ndp, ndt, ibo);
+ } else { //we are at the highest level
+
+ //get probability of lower order n-gram
+ ngram bng = ng;
+ bng.size--;
+ double blk = lprob(bng);
+
+ double wd = pow(10., tlk+lk) * (lk-bo-blk); //weighted difference score
+ if(aflag&&wd<0) wd=-wd;
+ if(wd > thr[elev-1]) { // kept
+ *ts += pow(10., lk);
+ *tbs += pow(10., blk);
+ } else { // discarded
+ ++nk;
+ prob(ndp, ndt, NOPROB); //mark entry as pruned
+ }
+ }
+ }
+ return nk;
+ }
+
+ int lmtable::pscale(int lev, table_entry_pos_t ipos, table_entry_pos_t epos, double s)
+ {
+ //rescales by factor s (in the probability domain) all non-pruned entries
+ //of level lev in the index range [ipos,epos)
+ LMT_TYPE ndt=tbltype[lev];
+ int ndsz=nodesize(ndt);
+ char *ndp;
+ float ipr;
+
+ s=log(s)/M_LN10; //convert the factor to a log10 additive term
+ ndp = table[lev]+ (table_pos_t) ipos*ndsz;
+ for(table_entry_pos_t i=ipos; i<epos; ndp+=ndsz,i++) {
+ ipr = prob(ndp, ndt);
+ if(ipr==NOPROB) continue; //skip pruned entries
+ ipr+=(float) s;
+ prob(ndp, ndt, ipr);
+ }
+ return 0;
+ }
+
+ //recompute table size by excluding pruned n-grams
+ table_entry_pos_t lmtable::ngcnt(table_entry_pos_t *cnt)
+ {
+ //fills cnt[1..maxlev] with per-level n-gram counts, excluding pruned entries
+ ngram ng(lmtable::getDict(),0);
+ memset(cnt, 0, (maxlev+1)*sizeof(*cnt));
+ ngcnt(cnt, ng, 1, 0, cursize[1]);
+ return 0;
+ }
+
+ //recursively compute size
+ table_entry_pos_t lmtable::ngcnt(table_entry_pos_t *cnt, ngram ng, int l, table_entry_pos_t ipos, table_entry_pos_t epos)
+ {
+ //recursive worker: counts non-pruned entries of level l in [ipos,epos)
+ //and descends into each entry's successor range
+ table_entry_pos_t i, isucc, esucc;
+ float ipr;
+ char *ndp;
+ LMT_TYPE ndt=tbltype[l];
+ int ndsz=nodesize(ndt);
+
+ ng.pushc(0);
+ for(i=ipos; i<epos; i++) {
+ ndp = table[l]+(table_pos_t) i*ndsz;
+ *ng.wordp(1)=word(ndp);
+ ipr=prob(ndp, ndt);
+ if(ipr==NOPROB) continue; //skip pruned entries
+ ++cnt[l];
+ if(l==maxlev) continue; //no successors at the top level
+ succrange(ndp,l,&isucc,&esucc);
+ if(isucc < esucc) ngcnt(cnt, ng, l+1, isucc, esucc);
+ }
+ return 0;
+ }
+}//namespace irstlm
+
diff --git a/src/lmtable.h b/src/lmtable.h
new file mode 100644
index 0000000..07e25d0
--- /dev/null
+++ b/src/lmtable.h
@@ -0,0 +1,660 @@
+// $Id: lmtable.h 3686 2010-10-15 11:55:32Z bertoldi $
+
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+
+
+#ifndef MF_LMTABLE_H
+#define MF_LMTABLE_H
+
+#ifndef WIN32
+#include <sys/types.h>
+#include <sys/mman.h>
+#endif
+
+#include <math.h>
+#include <cstdlib>
+#include <string>
+#include <set>
+#include <limits>
+#include "util.h"
+#include "ngramcache.h"
+#include "dictionary.h"
+#include "n_gram.h"
+#include "lmContainer.h"
+
+#define MAX(a,b) (((a)>(b))?(a):(b))
+#define MIN(a,b) (((a)<(b))?(a):(b))
+
+#define LMTMAXLEV 20
+#define MAX_LINE 100000
+
+#ifndef LMTCODESIZE
+#define LMTCODESIZE (int)3
+#endif
+
+#define SHORTSIZE (int)2
+#define PTRSIZE (int)sizeof(char *)
+#define INTSIZE (int)4
+#define CHARSIZE (int)1
+
+#define PROBSIZE (int)4 //use float
+#define QPROBSIZE (int)1 //use qfloat_t
+//#define BOUNDSIZE (int)4 //use table_pos_t
+#define BOUNDSIZE (int)sizeof(table_entry_pos_t) //use table_pos_t
+
+#define UNIGRAM_RESOLUTION 10000000.0
+
+typedef enum {INTERNAL,QINTERNAL,LEAF,QLEAF} LMT_TYPE;
+typedef char* node;
+
+typedef enum {LMT_FIND, //!< search: find an entry
+ LMT_ENTER, //!< search: enter an entry
+ LMT_INIT, //!< scan: start scan
+ LMT_CONT //!< scan: continue scan
+} LMT_ACTION;
+
+typedef unsigned int table_entry_pos_t; //type for pointing to a full ngram in the table
+typedef unsigned long table_pos_t; // type for pointing to a single char in the table
+typedef unsigned char qfloat_t; //type for quantized probabilities
+
+//CHECK this part to HERE
+
+#define BOUND_EMPTY1 (numeric_limits<table_entry_pos_t>::max() - 2)
+#define BOUND_EMPTY2 (numeric_limits<table_entry_pos_t>::max() - 1)
+
+namespace irstlm {
+class lmtable: public lmContainer
+{
+ static const bool debug=true;
+
+ void loadtxt(std::istream& inp,const char* header,const char* filename,int mmap);
+ void loadtxt_ram(std::istream& inp,const char* header);
+ void loadtxt_mmap(std::istream& inp,const char* header,const char* outfilename);
+ void loadtxt_level(std::istream& inp,int l);
+
+ void loadbin(std::istream& inp,const char* header,const char* filename,int mmap);
+ void loadbin_header(std::istream& inp, const char* header);
+ void loadbin_dict(std::istream& inp);
+ void loadbin_codebook(std::istream& inp,int l);
+ void loadbin_level(std::istream& inp,int l);
+
+protected:
+ char* table[LMTMAXLEV+1]; //storage of all levels
+ LMT_TYPE tbltype[LMTMAXLEV+1]; //table type for each levels
+ table_entry_pos_t cursize[LMTMAXLEV+1]; //current size of levels
+
+ //current offset for in-memory tables (different for each level
+ //needed to manage partial tables
+ // mempos = diskpos - offset[level]
+ table_entry_pos_t tb_offset[LMTMAXLEV+1];
+
+ table_entry_pos_t maxsize[LMTMAXLEV+1]; //max size of levels
+ table_entry_pos_t* startpos[LMTMAXLEV+1]; //support vector to store start positions
+ char info[100]; //information put in the header
+
+ //statistics
+ int totget[LMTMAXLEV+1];
+ int totbsearch[LMTMAXLEV+1];
+
+ //probability quantization
+ bool isQtable;
+
+ //Incomplete LM table from distributed training
+ bool isItable;
+
+ //Table with reverted n-grams for fast access
+ bool isInverted;
+
+ //Table might contain pruned n-grams
+ bool isPruned;
+
+ int NumCenters[LMTMAXLEV+1];
+ float* Pcenters[LMTMAXLEV+1];
+ float* Bcenters[LMTMAXLEV+1];
+
+ double logOOVpenalty; //penalty for OOV words (default 0)
+ int dictionary_upperbound; //set by user
+ int backoff_state;
+
+ //improve access speed
+ int max_cache_lev;
+
+// NGRAMCACHE_t* prob_and_state_cache;
+ NGRAMCACHE_t* prob_and_state_cache[LMTMAXLEV+1];
+ NGRAMCACHE_t* lmtcache[LMTMAXLEV+1];
+ float ngramcache_load_factor;
+ float dictionary_load_factor;
+
+ //memory map on disk
+ int memmap; //level from which n-grams are accessed via mmap
+ int diskid;
+ off_t tableOffs[LMTMAXLEV+1];
+ off_t tableGaps[LMTMAXLEV+1];
+
+ // is this LM queried for knowing the matching order or (standard
+ // case) for score?
+ bool orderQuery;
+
+ //flag to enable/disable deletion of dict in the destructor
+ bool delete_dict;
+
+public:
+
+#ifdef TRACE_CACHELM
+ std::fstream* cacheout;
+ int sentence_id;
+#endif
+
+ dictionary *dict; // dictionary (words - macro tags)
+
+ lmtable(float nlf=0.0, float dlfi=0.0);
+
+ virtual ~lmtable();
+
+ table_entry_pos_t wdprune(float *thr, int aflag=0);
+ table_entry_pos_t wdprune(float *thr, int aflag, ngram ng, int ilev, int elev, table_entry_pos_t ipos, table_entry_pos_t epos, double lk=0, double bo=0, double *ts=0, double *tbs=0);
+ double lprobx(ngram ong, double *lkp=0, double *bop=0, int *bol=0);
+
+ table_entry_pos_t ngcnt(table_entry_pos_t *cnt);
+ table_entry_pos_t ngcnt(table_entry_pos_t *cnt, ngram ng, int l, table_entry_pos_t ipos, table_entry_pos_t epos);
+ int pscale(int lev, table_entry_pos_t ipos, table_entry_pos_t epos, double s);
+
+ void init_prob_and_state_cache();
+ void init_probcache() {
+ init_prob_and_state_cache();
+ }; //kept for back compatibility
+ void init_statecache() {}; //kept for back compatibility
+ void init_lmtcaches();
+// void init_lmtcaches(int uptolev);
+ void init_caches(int uptolev);
+
+ void used_prob_and_state_cache() const;
+ void used_lmtcaches() const;
+ void used_caches() const;
+
+
+ void delete_prob_and_state_cache();
+ void delete_probcache() {
+ delete_prob_and_state_cache();
+ }; //kept for back compatibility
+ void delete_statecache() {}; //kept for back compatibility
+ void delete_lmtcaches();
+ void delete_caches();
+
+ void stat_prob_and_state_cache();
+ void stat_lmtcaches();
+ void stat_caches();
+
+ void check_prob_and_state_cache_levels() const;
+ void check_probcache_levels() const {
+ check_prob_and_state_cache_levels();
+ }; //kept for back compatibility
+ void check_statecache_levels() const{}; //kept for back compatibility
+ void check_lmtcaches_levels() const;
+ void check_caches_levels() const;
+
+ void reset_prob_and_state_cache();
+ void reset_probcache() {
+ reset_prob_and_state_cache();
+ }; //kept for back compatibility
+ void reset_statecache() {}; //kept for back compatibility
+ void reset_lmtcaches();
+ void reset_caches();
+
+
+ bool are_prob_and_state_cache_active() const;
+ bool is_probcache_active() const {
+ return are_prob_and_state_cache_active();
+ }; //kept for back compatibility
+ bool is_statecache_active() const {
+ return are_prob_and_state_cache_active();
+ }; //kept for back compatibility
+ bool are_lmtcaches_active() const;
+ bool are_caches_active() const;
+
+ void reset_mmap();
+
+ //set the inverted flag to load ngrams in an inverted order
+ //this choice is disregarded if a binary LM is loaded,
+ //because the info is stored into the header
+ inline bool is_inverted(const bool flag) { //set and return the inverted flag
+ return isInverted=flag;
+ }
+ inline bool is_inverted() const { //query the inverted flag
+ return isInverted;
+ }
+
+ void configure(int n,bool quantized);
+
+ //set penalty for OOV words
+ inline double getlogOOVpenalty() const {
+ return logOOVpenalty;
+ }
+
+ inline double setlogOOVpenalty(int dub) { //derive the OOV penalty from a dictionary upper bound
+ MY_ASSERT(dub > dict->size());
+ dictionary_upperbound = dub;
+ return logOOVpenalty=log((double)(dictionary_upperbound - dict->size()))/M_LN10; //log10 mass spread over unseen words
+ }
+
+ inline double setlogOOVpenalty(double oovp) { //set the OOV penalty directly
+ return logOOVpenalty=oovp;
+ }
+
+ virtual int maxlevel() const {
+ return maxlev;
+ };
+ inline bool isQuantized() const {
+ return isQtable;
+ }
+
+
+ void savetxt(const char *filename);
+ void savebin(const char *filename);
+
+ void appendbin_level(int level, fstream &out, int mmap);
+ void appendbin_level_nommap(int level, fstream &out);
+ void appendbin_level_mmap(int level, fstream &out);
+
+ void savebin_level(int level, const char* filename, int mmap);
+ void savebin_level_nommap(int level, const char* filename);
+ void savebin_level_mmap(int level, const char* filename);
+ void savebin_dict(std::fstream& out);
+
+ void compact_all_levels(const char* filename);
+ void compact_single_level(int level, const char* filename);
+
+ void concatenate_all_levels(const char* fromfilename, const char* tofilename);
+ void concatenate_single_level(int level, const char* fromfilename, const char* tofilename);
+
+ void remove_all_levels(const char* filename);
+ void remove_single_level(int level, const char* filename);
+
+ void print_table_stat();
+ void print_table_stat(int level);
+
+ void dumplm(std::fstream& out,ngram ng, int ilev, int elev, table_entry_pos_t ipos,table_entry_pos_t epos);
+
+
+ void delete_level(int level, const char* outfilename, int mmap);
+ void delete_level_nommap(int level);
+ void delete_level_mmap(int level, const char* filename);
+
+ void resize_level(int level, const char* outfilename, int mmap);
+ void resize_level_nommap(int level);
+ void resize_level_mmap(int level, const char* filename);
+
+ inline void update_offset(int level, table_entry_pos_t value) { tb_offset[level]=value; };
+
+
+ void load(const std::string &filename, int mmap=0);
+ void load(std::istream& inp,const char* filename=NULL,const char* outfilename=NULL,int mmap=0);
+
+ void load_centers(std::istream& inp,int l);
+
+ void expand_level(int level, table_entry_pos_t size, const char* outfilename, int mmap);
+ void expand_level_nommap(int level, table_entry_pos_t size);
+ void expand_level_mmap(int level, table_entry_pos_t size, const char* outfilename);
+
+ void cpsublm(lmtable* sublmt, dictionary* subdict,bool keepunigr=true);
+
+ int reload(std::set<string> words);
+
+ void filter(const char* /* unused parameter: lmfile */) {};
+
+
+ virtual double lprob(ngram ng, double* bow=NULL,int* bol=NULL,char** maxsuffptr=NULL,unsigned int* statesize=NULL, bool* extendible=NULL, double* lastbow=NULL);
+ virtual double clprob(ngram ng, double* bow=NULL,int* bol=NULL,char** maxsuffptr=NULL,unsigned int* statesize=NULL,bool* extendible=NULL);
+ virtual double clprob(int* ng, int ngsize, double* bow=NULL,int* bol=NULL,char** maxsuffptr=NULL,unsigned int* statesize=NULL,bool* extendible=NULL);
+
+
+ void *search(int lev,table_entry_pos_t offs,table_entry_pos_t n,int sz,int *w, LMT_ACTION action,char **found=(char **)NULL);
+
+ int mybsearch(char *ar, table_entry_pos_t n, int size, char *key, table_entry_pos_t *idx);
+
+
+ int add(ngram& ng, float prob,float bow);
+ //template<typename TA, typename TB> int add(ngram& ng, TA prob,TB bow);
+
+ int addwithoffset(ngram& ng, float prob,float bow);
+ // template<typename TA, typename TB> int addwithoffset(ngram& ng, TA prob,TB bow);
+
+ void checkbounds(int level);
+
+ inline int get(ngram& ng) { //lookup ng at its own size (convenience overload)
+ return get(ng,ng.size,ng.size);
+ }
+ int get(ngram& ng,int n,int lev);
+
+ int succscan(ngram& h,ngram& ng,LMT_ACTION action,int lev);
+
+ virtual const char *maxsuffptr(ngram ong, unsigned int* size=NULL);
+ virtual const char *cmaxsuffptr(ngram ong, unsigned int* size=NULL);
+ virtual const char *cmaxsuffptr(int* codes, int sz, unsigned int* size=NULL);
+
+ inline void putmem(char* ptr,int value,int offs,int size) { //store 'size' low bytes of value at ptr+offs, little-endian
+ MY_ASSERT(ptr!=NULL);
+ for (int i=0; i<size; i++)
+ ptr[offs+i]=(value >> (8 * i)) & 0xff;
+ };
+
+ inline void getmem(char* ptr,int* value,int offs,int size) { //read 'size' bytes from ptr+offs into *value, little-endian
+ MY_ASSERT(ptr!=NULL);
+ *value=ptr[offs] & 0xff;
+ for (int i=1; i<size; i++){
+ *value= *value | ( ( ptr[offs+i] & 0xff ) << (8 *i));
+ }
+ };
+
+ template<typename T>
+ inline void putmem(char* ptr,T value,int offs) { //raw byte-copy of a T into the table
+ MY_ASSERT(ptr!=NULL);
+ memcpy(ptr+offs, &value, sizeof(T));
+ };
+
+ template<typename T>
+ inline void getmem(char* ptr,T* value,int offs) { //raw byte-copy of a T out of the table
+ MY_ASSERT(ptr!=NULL);
+ memcpy((void*)value, ptr+offs, sizeof(T));
+ };
+
+
+ int nodesize(LMT_TYPE ndt) { //record size in bytes for a node of the given type
+ switch (ndt) {
+ case INTERNAL:
+ return LMTCODESIZE + PROBSIZE + PROBSIZE + BOUNDSIZE; //code + prob + bow + successor bound
+ case QINTERNAL:
+ return LMTCODESIZE + QPROBSIZE + QPROBSIZE + BOUNDSIZE; //quantized prob/bow (1 byte each)
+ case LEAF:
+ return LMTCODESIZE + PROBSIZE; //code + prob only
+ case QLEAF:
+ return LMTCODESIZE + QPROBSIZE;
+ default:
+ MY_ASSERT(0);
+ return 0;
+ }
+ }
+
+ inline int word(node nd,int value=-1) { //get (value==-1) or set the word code stored in a node
+ int offset=0;
+
+ if (value==-1)
+ getmem(nd,&value,offset,LMTCODESIZE);
+ else
+ putmem(nd,value,offset,LMTCODESIZE);
+
+ return value;
+ };
+
+
+ int codecmp(node a,node b) { //byte-wise compare of two word codes, most significant byte first
+ register int i,result;
+ for (i=(LMTCODESIZE-1); i>=0; i--) {
+ result=(unsigned char)a[i]-(unsigned char)b[i];
+ if(result) return result;
+ }
+ return 0;
+ };
+
+ int codediff(node a,node b) { //numeric difference between two word codes
+ return word(a)-word(b);
+ };
+
+
+ inline float prob(node nd,LMT_TYPE ndt) { //read the (possibly quantized) probability field of a node
+ int offs=LMTCODESIZE;
+
+ float fv;
+ unsigned char cv;
+ switch (ndt) {
+ case INTERNAL:
+ getmem(nd,&fv,offs);
+ return fv;
+ case QINTERNAL:
+ getmem(nd,&cv,offs);
+ return (float) cv; //quantized: returns the codebook index, not the value
+ case LEAF:
+ getmem(nd,&fv,offs);
+ return fv;
+ case QLEAF:
+ getmem(nd,&cv,offs);
+ return (float) cv;
+ default:
+ MY_ASSERT(0);
+ return 0;
+ }
+ };
+
+ template<typename T>
+ inline T prob(node nd, LMT_TYPE ndt, T value) { //write the probability field; quantized types store a single byte
+ int offs=LMTCODESIZE;
+
+ switch (ndt) {
+ case INTERNAL:
+ putmem(nd, value,offs);
+ break;
+ case QINTERNAL:
+ putmem(nd,(unsigned char) value,offs);
+ break;
+ case LEAF:
+ putmem(nd, value,offs);
+ break;
+ case QLEAF:
+ putmem(nd,(unsigned char) value,offs);
+ break;
+ default:
+ MY_ASSERT(0);
+ return (T) 0;
+ }
+
+ return value;
+ };
+
+
+ inline float bow(node nd,LMT_TYPE ndt) { //read the (possibly quantized) back-off weight field of a node
+ int offs=LMTCODESIZE+(ndt==QINTERNAL?QPROBSIZE:PROBSIZE); //bow sits right after the prob field
+
+ float fv;
+ unsigned char cv;
+ switch (ndt) {
+ case INTERNAL:
+ getmem(nd,&fv,offs);
+ return fv;
+ case QINTERNAL:
+ getmem(nd,&cv,offs);
+ return (float) cv; //quantized: returns the codebook index, not the value
+ case LEAF:
+ getmem(nd,&fv,offs);
+ return fv;
+ case QLEAF:
+ getmem(nd,&cv,offs);
+ return (float) cv;
+ default:
+ MY_ASSERT(0);
+ return 0;
+ }
+ };
+
+ template<typename T>
+ inline T bow(node nd,LMT_TYPE ndt, T value) { //write the back-off weight field; quantized types store a single byte
+ int offs=LMTCODESIZE+(ndt==QINTERNAL?QPROBSIZE:PROBSIZE);
+
+ switch (ndt) {
+ case INTERNAL:
+ putmem(nd, value,offs);
+ break;
+ case QINTERNAL:
+ putmem(nd,(unsigned char) value,offs);
+ break;
+ case LEAF:
+ putmem(nd, value,offs);
+ break;
+ case QLEAF:
+ putmem(nd,(unsigned char) value,offs);
+ break;
+ default:
+ MY_ASSERT(0);
+ return 0;
+ }
+
+ return value;
+ };
+
+
+ inline table_entry_pos_t boundwithoffset(node nd,LMT_TYPE ndt, int level){ return bound(nd,ndt) - tb_offset[level+1]; } //bound converted to in-memory index of a partial table
+
+ inline table_entry_pos_t boundwithoffset(node nd,LMT_TYPE ndt, table_entry_pos_t value, int level){ return bound(nd, ndt, value + tb_offset[level+1]); } //store bound converted back to on-disk index
+
+ // table_entry_pos_t bound(node nd,LMT_TYPE ndt, int level=0) {
+ table_entry_pos_t bound(node nd,LMT_TYPE ndt) { //read the successor-range bound field of a node
+
+ int offs=LMTCODESIZE+2*(ndt==QINTERNAL?QPROBSIZE:PROBSIZE); //bound sits after prob and bow
+
+ table_entry_pos_t value;
+
+ getmem(nd,&value,offs);
+
+ // value -= tb_offset[level+1];
+
+ return value;
+ };
+
+
+ // table_entry_pos_t bound(node nd,LMT_TYPE ndt, table_entry_pos_t value, int level=0) {
+ table_entry_pos_t bound(node nd,LMT_TYPE ndt, table_entry_pos_t value) { //write the successor-range bound field of a node
+
+ int offs=LMTCODESIZE+2*(ndt==QINTERNAL?QPROBSIZE:PROBSIZE);
+
+ // value += tb_offset[level+1];
+
+ putmem(nd,value,offs);
+
+ return value;
+ };
+
+ //template<typename T> T boundwithoffset(node nd,LMT_TYPE ndt, T value, int level);
+
+ /*
+ table_entry_pos_t boundwithoffset(node nd,LMT_TYPE ndt, int level) {
+
+ int offs=LMTCODESIZE+2*(ndt==QINTERNAL?QPROBSIZE:PROBSIZE);
+
+ table_entry_pos_t value;
+
+ getmem(nd,&value,offs);
+ return value;
+ // return value-tb_offset[level+1];
+ };
+ */
+
+ /*
+ table_entry_pos_t boundwithoffset(node nd,LMT_TYPE ndt, table_entry_pos_t value, int level) {
+
+ int offs=LMTCODESIZE+2*(ndt==QINTERNAL?QPROBSIZE:PROBSIZE);
+
+ putmem(nd,value,offs);
+
+ return value;
+ // return value+tb_offset[level+1];
+ };
+ */
+
+ /*
+ inline table_entry_pos_t bound(node nd,LMT_TYPE ndt) {
+
+ int offs=LMTCODESIZE+2*(ndt==QINTERNAL?QPROBSIZE:PROBSIZE);
+
+ table_entry_pos_t value;
+
+ getmem(nd,&value,offs);
+ return value;
+ };
+
+ template<typename T>
+ inline T bound(node nd,LMT_TYPE ndt, T value) {
+
+ int offs=LMTCODESIZE+2*(ndt==QINTERNAL?QPROBSIZE:PROBSIZE);
+
+ putmem(nd,value,offs);
+
+ return value;
+ };
+ */
+ //returns the indexes of the successors of a node
+ int succrange(node ndp,int level,table_entry_pos_t* isucc=NULL,table_entry_pos_t* esucc=NULL);
+
+ void stat(int lev=0);
+ void printTable(int level);
+
+ virtual inline void setDict(dictionary* d) { //replace the dictionary; takes no ownership of d
+ if (delete_dict==true && dict) delete dict; //free the previously owned dictionary
+ dict=d;
+ delete_dict=false;
+ };
+
+ inline dictionary* getDict() const { //current dictionary (may be externally owned)
+ return dict;
+ };
+
+ inline table_entry_pos_t getCurrentSize(int l) const { //number of entries at level l
+ return cursize[l];
+ };
+
+ inline void setOrderQuery(bool v) { //query mode: matching order instead of score
+ orderQuery = v;
+ }
+ inline bool isOrderQuery() const {
+ return orderQuery;
+ }
+
+ inline float GetNgramcacheLoadFactor() { //load factor used for the n-gram caches
+ return ngramcache_load_factor;
+ }
+ inline float GetDictionaryLoadFactor() { //load factor used for the dictionary
+ return dictionary_load_factor; //fixed copy-paste bug: previously returned ngramcache_load_factor
+ }
+
+ //never allow the increment of the dictionary through this function
+ //never allow the increment of the dictionary through this function
+ inline virtual void dictionary_incflag(const bool flag) { //intentional no-op override
+ UNUSED(flag);
+ };
+
+ inline virtual bool filter(const string sfilter, lmtable* sublmt, const string skeepunigrams) { //copy into sublmt the sub-LM restricted to the words listed in file sfilter
+ std::cerr << "filtering... \n";
+ dictionary *dict=new dictionary((char *)sfilter.c_str()); //local dict shadows the member on purpose
+
+ cpsublm(sublmt, dict,(skeepunigrams=="yes"));
+ delete dict;
+ std::cerr << "...done\n";
+ return true;
+ }
+
+
+ inline virtual bool is_OOV(int code) { //true if code is the dictionary's out-of-vocabulary code
+ return (code == dict->oovcode());
+ };
+
+};
+
+}//namespace irstlm
+
+#endif
+
diff --git a/src/mdiadapt.cpp b/src/mdiadapt.cpp
new file mode 100644
index 0000000..ba60da1
--- /dev/null
+++ b/src/mdiadapt.cpp
@@ -0,0 +1,2175 @@
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+
+#include <cmath>
+#include <string>
+#include "util.h"
+#include "mfstream.h"
+#include "mempool.h"
+#include "htable.h"
+#include "dictionary.h"
+#include "n_gram.h"
+#include "mempool.h"
+#include "ngramcache.h"
+#include "ngramtable.h"
+#include "normcache.h"
+#include "interplm.h"
+#include "mdiadapt.h"
+#include "shiftlm.h"
+#include "lmtable.h"
+
+using namespace std;
+
+namespace irstlm {
+
+#ifdef MDIADAPTLM_CACHE_ENABLE
+#if MDIADAPTLM_CACHE_ENABLE==0
+#undef MDIADAPTLM_CACHE_ENABLE
+#endif
+#endif
+
+#ifdef MDIADAPTLM_CACHE_ENABLE
+ bool mdiadaptlm::mdiadaptlm_cache_enable=true;
+#else
+ bool mdiadaptlm::mdiadaptlm_cache_enable=false;
+#endif
+
//
// Minimum discrimination information (MDI) adaptation for interplm.
//
// Build an MDI-adaptable LM on top of the background interpolated LM
// loaded from 'ngtfile'. Adaptation is off (adaptlev==0) until adapt()
// is called; saving defaults to the per-level strategy.
mdiadaptlm::mdiadaptlm(char* ngtfile,int depth,TABLETYPE tbtype):
  interplm(ngtfile,depth,tbtype)
{
  adaptlev=0;            // no adaptation active yet
  forelm=NULL;           // foreground (adaptation) LM, created by scalefact()
  cache=NULL;            // zeta normalization cache, created by scalefact()
  m_save_per_level=true; // default saving strategy
};
+
// Release the zeta normalization cache and the per-level prob/backoff caches.
// NOTE(review): 'forelm' (allocated in scalefact()) is never freed here —
// looks like a leak; confirm ownership before changing.
mdiadaptlm::~mdiadaptlm()
{
  if (cache) delete cache;
  delete_caches();
};
+
+ void mdiadaptlm::delete_caches(int level)
+ {
+ if (probcache[level]) delete probcache[level];
+ if (backoffcache[level]) delete backoffcache[level];
+ };
+
// Free every per-level cache and the pointer arrays themselves.
// Entire body is compiled out unless caching is enabled.
void mdiadaptlm::delete_caches()
{
#ifdef MDIADAPTLM_CACHE_ENABLE
  for (int i=0; i<=max_caching_level; i++) delete_caches(i);

  delete [] probcache;
  delete [] backoffcache;
#endif
};
+
// Print usage statistics of all per-level caches to stderr
// (no-op unless caching is compiled in).
void mdiadaptlm::caches_stat()
{
#ifdef MDIADAPTLM_CACHE_ENABLE
  for (int i=1; i<=max_caching_level; i++) {
    if (probcache[i]) {
      cerr << "Statistics of probcache at level " << i << " (of " << maxlevel() << ") ";
      probcache[i]->stat();
    }
    if (backoffcache[i]) {
      cerr << "Statistics of backoffcache at level " << i << " (of " << maxlevel() << ") ";
      backoffcache[i]->stat();
    }
  }
#endif
};
+
+
// Allocate the per-level cache pointer arrays and initialize the caches.
// 'mcl' caps the highest cached level; out-of-range values fall back to
// lmsize()-1.
// NOTE(review): calling this twice would leak the previously allocated
// arrays — there is no reentry guard.
void mdiadaptlm::create_caches(int mcl)
{
  max_caching_level=(mcl>=0 && mcl<lmsize())?mcl:lmsize()-1;

  probcache = new NGRAMCACHE_t*[max_caching_level+1]; //index 0 will never be used, index=max_caching_level is not used
  backoffcache = new NGRAMCACHE_t*[max_caching_level+1]; //index 0 will never be used, index=max_caching_level is not used
  for (int i=0; i<=max_caching_level; i++) {
    probcache[i]=NULL;
    backoffcache[i]=NULL;
  }

  init_caches();
}
+
+
// Create the prob/backoff caches for one level (slots must be empty).
// Keys are 'level' words wide; payload is a double; initial size 400000.
void mdiadaptlm::init_caches(int level)
{
  MY_ASSERT(probcache[level]==NULL);
  MY_ASSERT(backoffcache[level]==NULL);
  probcache[level]=new NGRAMCACHE_t(level,sizeof(double),400000);
  backoffcache[level]=new NGRAMCACHE_t(level,sizeof(double),400000);
};
+
// Create caches for all levels 1..max_caching_level (level 0 is unused).
void mdiadaptlm::init_caches()
{
#ifdef MDIADAPTLM_CACHE_ENABLE
  for (int i=1; i<=max_caching_level; i++) init_caches(i);
#endif
};
+
+ void mdiadaptlm::check_cache_levels(int level)
+ {
+ if (probcache[level] && probcache[level]->isfull()) probcache[level]->reset(probcache[level]->cursize());
+ if (backoffcache[level] && backoffcache[level]->isfull()) backoffcache[level]->reset(backoffcache[level]->cursize());
+ };
+
// Flush any full cache on every cached level (no-op unless caching is on).
void mdiadaptlm::check_cache_levels()
{
#ifdef MDIADAPTLM_CACHE_ENABLE
  for (int i=1; i<=max_caching_level; i++) check_cache_levels(i);
#endif
};
+
+ void mdiadaptlm::reset_caches(int level)
+ {
+ if (probcache[level]) probcache[level]->reset(MAX(probcache[level]->cursize(),probcache[level]->maxsize()));
+ if (backoffcache[level]) backoffcache[level]->reset(MAX(backoffcache[level]->cursize(),backoffcache[level]->maxsize()));
+ };
+
// Reset every cached level (no-op unless caching is compiled in).
void mdiadaptlm::reset_caches()
{
#ifdef MDIADAPTLM_CACHE_ENABLE
  for (int i=1; i<=max_caching_level; i++) reset_caches(i);
#endif
};
+
+
// Raw access to the probability cache of 'level' (may be NULL).
inline NGRAMCACHE_t* mdiadaptlm::get_probcache(int level)
{
  return probcache[level];
}
+
// Raw access to the back-off cache of 'level' (may be NULL).
inline NGRAMCACHE_t* mdiadaptlm::get_backoffcache(int level)
{
  return backoffcache[level];
}
+
// (Re)build the foreground model and normalization cache from 'ngtfile'
// and precompute 'oovscaling': the scaling factor for OOV words, i.e.
// P_fore(oov) divided by the background unigram mass NOT covered by the
// foreground vocabulary. Exits when the adaptation data contains words
// unknown to the background dictionary. Always returns 1.
int mdiadaptlm::scalefact(char *ngtfile)
{
  if (forelm!=NULL) delete forelm;
  if (cache!=NULL) delete cache;
  cache=new normcache(dict);

  // foreground model: ShiftBeta-smoothed unigram LM over the adaptation data
  forelm=new shiftbeta(ngtfile,1);
  forelm->train();

  //compute oov scalefact term
  ngram fng(forelm->dict,1);
  ngram ng(dict,1);
  int* w=fng.wordp(1);

  oovscaling=1.0;
  for ((*w)=0; (*w)<forelm->dict->size(); (*w)++)
    if ((*w) != forelm->dict->oovcode()) {
      ng.trans(fng); // map the foreground word into the background dictionary
      if (*ng.wordp(1)==dict->oovcode()) {
        cerr << "adaptation file contains new words: use -ao=yes option\n";
        exit(1);
      }
      //forbidden situation
      oovscaling-=backunig(ng); // subtract covered background mass
    }
  *w=forelm->dict->oovcode();
  oovscaling=foreunig(fng)/oovscaling;

  return 1;
};
+
+ int mdiadaptlm::savescalefactor(char* filename)
+ {
+
+ ngram ng(dict,1);
+ int* w=ng.wordp(1);
+
+ mfstream out(filename,ios::out);
+
+ out << "\n\\data\\" << "\nngram 1=" << dict->size() << "\n\n1grams:\n";
+
+ for ((*w)=0; (*w)<dict->size(); (*w)++) {
+ double ratio=scalefact(ng);
+ out << (float) (ratio?log10(ratio):-99);
+ if (*w==dict->oovcode())
+ out << "\t" << "<unk>\n";
+ else
+ out << "\t" << (char *)dict->decode(*w) << "\n";
+
+ }
+ out << "\\end\\\n";
+
+ return 1;
+ }
+
+ double mdiadaptlm::scalefact(ngram ng)
+ {
+ ngram fng(forelm->dict,1);
+ fng.trans(ng);
+ if (*fng.wordp(1)==forelm->dict->oovcode())
+ return pow(oovscaling,gis_step);
+ else {
+ double prback=backunig(ng);
+ double prfore=foreunig(ng);
+ return pow(prfore/prback,gis_step);
+ }
+ }
+
+
+ double mdiadaptlm::foreunig(ngram ng)
+ {
+
+ double fstar,lambda;
+
+ forelm->discount(ng,1,fstar,lambda);
+
+ return fstar;
+ }
+
+ double mdiadaptlm::backunig(ngram ng)
+ {
+
+ double fstar,lambda;
+
+ discount(ng,1,fstar,lambda,0);
+
+ return fstar;
+ };
+
+
+
// Activate MDI adaptation using the n-gram statistics in 'ngtfile'.
// 'alev' (clamped to [1,lmsize()]) is the highest adapted order; 'step'
// is the GIS exponent applied to the scaling ratios. Precomputes the
// unigram normalizer zeta0 and, when alev>1, warms the bigram zeta
// cache. Always returns 1; exits if 'ngtfile' is missing.
int mdiadaptlm::adapt(char* ngtfile,int alev,double step)
{
  // clamp the requested adaptation level to the model order
  if (alev > lmsize() || alev<=0) {
    cerr << "setting adaptation level to " << lmsize() << "\n";
    alev=lmsize();
  }
  adaptlev=alev;


  cerr << "adapt ....";
  gis_step=step;

  if (ngtfile==NULL) {
    cerr << "adaptation file is missing\n";
    exit(1);
  }

  //compute the scaling factor;

  scalefact(ngtfile);

  //compute 1-gram zeta
  ngram ng(dict,2);
  int* w=ng.wordp(1);

  cerr << "precomputing 1-gram normalization ...\n";
  zeta0=0;
  for ((*w)=0; (*w)<dict->size(); (*w)++)
    zeta0+=scalefact(ng) * backunig(ng);

  if (alev==1) return 1 ;

  cerr << "precomputing 2-gram normalization:\n";

  //precompute the bigram normalization
  w=ng.wordp(2);
  *ng.wordp(1)=0;

  for ((*w)=0; (*w)<dict->size(); (*w)++) {
    zeta(ng,2); // side effect: stores the value in the zeta cache
    if ((*w % 1000)==0) cerr << ".";
  }

  cerr << "done\n";

  return 1;
};
+
+
// Normalization term Z(h) of the MDI-scaled distribution at order 'size':
// the scaled fstar mass of all observed successors of the history, plus
// lambda(h) times the lower-order term, bottoming out at zeta0 for
// unigrams. 2- and 3-gram values are memoized in 'cache'.
double mdiadaptlm::zeta(ngram ng,int size)
{

  MY_ASSERT(size>=1);

  double z=0; // compute normalization term

  ng.size=size;

  if (size==1) return zeta0;
  else { //size>1

    //check in the 2gr and 3gr cache
    if (size <=3 && cache->get(ng,size,z)) return z;

    double fstar,lambda;
    ngram histo=ng;
    int succ=0;

    discount(ng,size,fstar,lambda,(int)0);

    if ((lambda<1) && get(histo,size,size-1)) {
      ;

      //scan all its successors
      succ=0;

      succscan(histo,ng,INIT,size);
      while(succscan(histo,ng,CONT,size)) {

        discount(ng,size,fstar,lambda,0);
        if (fstar>0) {
          z+=(scalefact(ng) * fstar);
          succ++;
          //cerr << ng << "zeta= " << z << "\n";
        }
      }
    }

    // 'lambda' at this point is whatever the last discount() call set;
    // since every successor shares the same history, lambda is presumably
    // history-only and identical across iterations — confirm if discount()
    // is ever changed to make lambda word-dependent.
    z+=lambda*zeta(ng,size-1);

    if (size<=3 && succ>1) cache->put(ng,size,z);

    return z;
  }

}
+
+
// MDI-adapted discounting: compute fstar (discounted probability mass)
// and lambda (back-off/interpolation weight) of 'ng_' at order 'size'.
// Base estimates come from the underlying model's discount(); when
// adaptation is active (size<=adaptlev and lambda<1) both are rescaled
// by the GIS factor and normalized by zeta. Per-history lambda values
// are cached when caching is compiled in. Always returns 1. The last
// parameter (cv) is unused.
int mdiadaptlm::discount(ngram ng_,int size,double& fstar,double& lambda,int /* unused parameter: cv */)
{
  VERBOSE(3,"mdiadaptlm::discount(ngram ng_,int size,double& fstar,double& lambda,int)) ng_:|" << ng_ << "| size:" << size << std::endl);

  ngram ng(dict);
  ng.trans(ng_);

  double __fstar, __lambda;
  bool lambda_cached=0;
  int size_lambda=size-1; // lambda depends only on the history (size-1 words)

  ngram histo=ng;
  histo.shift();

  if (size_lambda>0 && histo.size>=size_lambda) {
#ifdef MDIADAPTLM_CACHE_ENABLE
    if (size_lambda<=max_caching_level) {
      //backoffcache hit
      if (backoffcache[size_lambda] && backoffcache[size_lambda]->get(histo.wordp(size_lambda),__lambda))
        lambda_cached=1;
    }
#endif
  }

  // base (unadapted) estimates
  discount(ng,size,__fstar,__lambda,0);

  if ((size>0) && (size<=adaptlev) && (__lambda<1)) {

    if (size>1) {
      double numlambda, numfstar, den;
      numfstar=scalefact(ng);
      den=zeta(ng,size);
      __fstar=__fstar * numfstar/den;
      if (!lambda_cached) {
        numlambda=zeta(ng,size-1);
        __lambda=__lambda * numlambda/den;
      }
    } else if (size==1) {
      // unigram level: normalize by the precomputed zeta0
      double ratio;
      ratio=scalefact(ng)/zeta0;
      __fstar=__fstar * ratio;
      if (!lambda_cached) {
        __lambda=__lambda * ratio;
      }
    } else {
      //size==0 do nothing
    }
  }

#ifdef MDIADAPTLM_CACHE_ENABLE
  //backoffcache insert
  if (!lambda_cached && size_lambda>0 && size_lambda<=max_caching_level && histo.size>=size_lambda && backoffcache[size_lambda])
    backoffcache[size_lambda]->add(histo.wordp(size_lambda),__lambda);
#endif

  lambda=__lambda;
  fstar=__fstar;

  return 1;
}
+
+ int mdiadaptlm::compute_backoff()
+ {
+ VERBOSE(3,"mdiadaptlm::compute_backoff() ");
+ if (m_save_per_level){
+ VERBOSE(3," per level ...\n");
+ return mdiadaptlm::compute_backoff_per_level();
+ }else{
+ VERBOSE(3," per word ...\n");
+ return mdiadaptlm::compute_backoff_per_word();
+ }
+ }
+
// For every history h of order 1..lmsize()-1, compute and store the
// back-off denominator boff(h) = 1 - sum of prob(w|h) over successors w
// whose discounted fstar is positive. Sets this->backoff=1 as a side
// effect. Always returns 1.
int mdiadaptlm::compute_backoff_per_level()
{
  VERBOSE(3,"mdiadaptlm::compute_backoff_per_level()\n");
  double fstar,lambda;

  this->backoff=1;

  for (int size=1; size<lmsize(); size++) {

    ngram hg(dict,size);

    scan(hg,INIT,size);

    while(scan(hg,CONT,size)) {
      ngram ng=hg;
      ng.pushc(0); //ng.size is now hg.size+1
      double pr=1.0;

      succscan(hg,ng,INIT,size+1);
      while(succscan(hg,ng,CONT,size+1)) {
        mdiadaptlm::discount(ng,ng.size,fstar,lambda);
        if (fstar>0){
          // NOTE(review): ng.size is shrunk here and not visibly restored
          // before the next succscan iteration — presumably succscan resets
          // it; confirm against the ngramtable implementation.
          ng.size=ng.size-1;
          pr -= mdiadaptlm::prob(ng,size);
        }
      }

      MY_ASSERT(pr>=LOWER_SINGLE_PRECISION_OF_0 && pr<=UPPER_SINGLE_PRECISION_OF_1);

      boff(hg.link,pr); // store the residual mass on the history node
    }
  }

  VERBOSE(3,"mdiadaptlm::compute_backoff_per_level() DONE\n");

  return 1;
}
+
+
// Per-word saving cannot currently be combined with back-off mixture
// models: report the limitation and abort.
// NOTE(review): "-saveperllevel" in both messages looks like a typo for
// "-saveperlevel" — confirm against the command-line option table.
int mdiadaptlm::compute_backoff_per_word()
{
  cerr << "Current implementation does not support the usage of backoff (-bo=yes) mixture models (-lm=mix) combined with the per-word saving (-saveperllevel=no)." << endl;
  cerr << "Please, either choose a per-level saving (-saveperllevel=yes) or do not use backoff (-bo=no) " << endl;

  exit(1);
}
+
+
+ double mdiadaptlm::prob2(ngram ng,int size,double& fstar)
+ {
+ double lambda;
+
+ mdiadaptlm::discount(ng,size,fstar,lambda);
+
+ if (size>1)
+ return fstar + lambda * prob(ng,size-1);
+ else
+ return fstar;
+ }
+
+
// Convenience overload: probability of 'ng' at order 'size', discarding
// the fstar/lambda/bo diagnostics of the full version below.
double mdiadaptlm::prob(ngram ng,int size)
{
  double fstar,lambda,bo;
  return prob(ng,size,fstar,lambda,bo);
}
+
// Full probability of 'ng' at order 'size'. Combines fstar/lambda/bo
// from bodiscount() either in back-off style (use fstar if positive,
// otherwise recurse on the lower order weighted by lambda/bo) or in
// interpolation style (fstar + lambda * lower-order prob). Out-of-range
// fstar/lambda values are clipped to 1 with a warning rather than
// aborting. Results are memoized per level when caching is enabled.
double mdiadaptlm::prob(ngram ng,int size,double& fstar,double& lambda, double& bo)
{
  VERBOSE(3,"mdiadaptlm::prob(ngram ng,int size,double& fstar,double& lambda, double& bo) ng:|" << ng << "| size:" << size << std::endl);
  double pr;

#ifdef MDIADAPTLM_CACHE_ENABLE
  //probcache hit
  if (size<=max_caching_level && probcache[size] && ng.size>=size && probcache[size]->get(ng.wordp(size),pr))
    return pr;
#endif

  //probcache miss
  mdiadaptlm::bodiscount(ng,size,fstar,lambda,bo);
  VERBOSE(3,"mdiadaptlm::prob(ngram ng,int size,double& fstar,double& lambda, double& bo) after bodiscount @@@@@@@@@ ng:|" << ng << "| size:" << size << "| fstar:" << fstar << "| lambda:" << lambda << "| bo:" << bo << std::endl);
  if (fstar>UPPER_SINGLE_PRECISION_OF_1 || lambda>UPPER_SINGLE_PRECISION_OF_1) {
    cerr << "wrong probability: " << ng
         << " , size " << size
         << " , fstar " << fstar
         << " , lambda " << lambda << "\n";
    // clip instead of aborting; see the disabled exit(1) below
    fstar=(fstar>UPPER_SINGLE_PRECISION_OF_1?UPPER_SINGLE_PRECISION_OF_1:fstar);
    lambda=(lambda>UPPER_SINGLE_PRECISION_OF_1?UPPER_SINGLE_PRECISION_OF_1:lambda);
    //exit(1);
  }
  if (backoff) {
    if (size>1) {
      if (fstar>0){
        pr=fstar;
      }else {
        if (lambda<1){
          pr = lambda/bo * prob(ng,size-1);
        }else {
          MY_ASSERT(lambda<UPPER_SINGLE_PRECISION_OF_1);
          pr = prob(ng,size-1);
        }
      }
    } else
      pr = fstar;
  }

  else { //interpolation
    if (size>1)
      pr = fstar + lambda * prob(ng,size-1);
    else
      pr = fstar;
  }

#ifdef MDIADAPTLM_CACHE_ENABLE
  //probcache insert
  if (size<=max_caching_level && probcache[size] && ng.size>=size)
    probcache[size]->add(ng.wordp(size),pr);
#endif
  VERBOSE(3,"mdiadaptlm::prob(ngram ng,int size,double& fstar,double& lambda, double& bo) returning ng:|" << ng << "| pr:" << pr << std::endl);
  return pr;
}
+
+
// Compute fstar and lambda at order 'size' plus the back-off denominator
// 'bo' of the history. 'bo' stays 1.0 unless back-off smoothing is
// active, in which case it is read from the history node previously
// filled by compute_backoff_per_level(). Always returns 1.
int mdiadaptlm::bodiscount(ngram ng_,int size,double& fstar,double& lambda,double& bo)
{
  VERBOSE(3,"mdiadaptlm::bodiscount(ngram ng_,int size,double& fstar,double& lambda,double& bo) ng_:|" << ng_ << "| size:" << size << std::endl);
  ngram ng(dict);
  ng.trans(ng_);

  mdiadaptlm::discount(ng,size,fstar,lambda);

  bo=1.0;

  if (backoff) { //get back-off probability

    if (size>1 && lambda<1) {

      ngram hg=ng;

      // cerr<< "hg:|" << hg << "| size:|" << size << "|" << endl;
      if (! get(hg,size,size-1)){
        cerr << "ERROR: int mdiadaptlm::bodiscount(ngram ng_,int size,double& fstar,double& lambda,double& bo) -> get(hg,size,size-1) returns NULL\n";
      }
      MY_ASSERT(get(hg,size,size-1));

      bo=boff(hg.link); // stored back-off denominator of the history

      // if (lambda > bo){
      // cerr << " mdiadaptlm::bodiscount ERROR: " << " lambda:" << lambda << " bo:" << bo << "\n";
      // exit(1);
      // }
    }
  }

  return 1;
}
+
+
// Interpolated probability variant whose unigram level falls back to a
// frequency-based estimate over the dub()-extended vocabulary (add-one
// style) instead of the adapted unigram distribution.
double mdiadaptlm::txclprob(ngram ng,int size)
{

  double fstar,lambda;

  if (size>1) {
    mdiadaptlm::discount(ng,size,fstar,lambda);
    return fstar + lambda * txclprob(ng,size-1);
  } else {
    // unigram: (freq+1) / (total + dub - observed vocabulary)
    double freq=1;
    if ((*ng.wordp(1)!=dict->oovcode()) && get(ng,1,1))
      freq+=ng.freq;

    double N=totfreq()+dict->dub()-dict->size();
    return freq/N;
  }
}
+
+
// Estimate the number of parameters of the LM and print a per-level
// report to stdout. Note the printed per-level figure is the running
// cumulative total, not the level's own count.
// NOTE(review): 'size+=size * (i<maxlevel())' doubles the count for
// non-top levels — presumably one slot for the probability and one for
// the back-off weight; confirm before relying on the exact figure.
int mdiadaptlm::netsize()
{
  double fstar,lambda;
  int size,totsize;
  ngram ng(dict);

  cerr << "Computing LM size:\n";

  totsize=dict->size() * 2; // unigrams: prob + back-off per word

  cout << "1-gram " << totsize << "\n";

  for (int i=2; i<=maxlevel(); i++) {

    size=0;

    scan(ng,INIT,i);

    while (scan(ng,CONT,i)) {

      mdiadaptlm::discount(ng,i,fstar,lambda);

      if (fstar>0) size++;

    }

    size+=size * (i<maxlevel());

    totsize+=size;

    cout << i << "-gram " << totsize << "\n";

  }

  return totsize;
}
+
+
+
+ /*
+ * trigram file format:
+
+ --------------------------------
+
+ <idx> dictionary length
+
+ repeat [ dictionary length ] {
+ <newline terminated string> word;
+ }
+
+ while [ first word != STOP ] {
+ <idx> first word
+ <idx> number of successors
+ repeat [ number of successors ] {
+ <idx> second word
+ <float> prob
+ }
+ }
+
+ <idx> STOP
+
+ while [ first word != STOP ] {
+ <idx> first word
+ <idx> number of successor sets
+ repeat [ number of successor sets ] {
+ <idx> second word
+ <idx> number of successors
+ repeat [ number of successors ] {
+ <idx> third word
+ <float> prob
+ }
+ }
+ }
+
+ <idx> STOP
+
+ */
+
+
+ //void writeNull(mfbstream& out,unsigned short nullCode,float nullProb){
+ // out.writex(&nullCode,sizeof(short));
+ // out.writex(&nullProb,sizeof(float));
+ //}
+
+
// Reverse the byte order of each of the 'n' records of 'sz' bytes at p,
// converting between little- and big-endian layouts in place.
// Records shorter than 2 bytes, or n<1, are left untouched.
// Always returns 0.
int swapbytes(char *p, int sz, int n)
{
  if (n<1 || sz<2) return 0;

  for (int i=0; i<n; i++) {
    char* lo = p + i*sz;
    char* hi = lo + sz - 1;
    while (lo < hi) {
      char tmp = *lo;
      *lo = *hi;
      *hi = tmp;
      ++lo;
      --hi;
    }
  }
  return 0;
};
+
// Endian-normalizing fwrite: write the 'n' records of 'sz' bytes at p in
// big-endian byte order regardless of host endianness (the "AB" probe
// detects a little-endian host), restoring the buffer to its original
// host layout before returning. Byte reversal is done inline, so records
// of fewer than 2 bytes are naturally left as-is.
void fwritex(char *p,int sz,int n,FILE* f)
{
  const bool host_is_little = (*(short *)"AB"==0x4241);

  if (host_is_little) {
    for (int i=0; i<n; i++) {
      char* lo=p+i*sz;
      char* hi=lo+sz-1;
      while (lo<hi) { char t=*lo; *lo=*hi; *hi=t; ++lo; --hi; }
    }
  }

  fwrite((char *)p,sz,n,f);

  if (host_is_little) { // undo the swap: caller's buffer must be unchanged
    for (int i=0; i<n; i++) {
      char* lo=p+i*sz;
      char* hi=lo+sz-1;
      while (lo<hi) { char t=*lo; *lo=*hi; *hi=t; ++lo; --hi; }
    }
  }
}
+
+ void ifwrite(long loc,void *ptr,int size,int /* unused parameter: n */,FILE* f)
+ {
+ fflush(f);
+
+ long pos=ftell(f);
+
+ fseek(f,loc,SEEK_SET);
+
+ fwritex((char *)ptr,size,1,f);
+
+ fseek(f,pos,SEEK_SET);
+
+ fflush(f);
+ }
+
// Emit a "NULL successor" record — the reserved code followed by its
// (back-off) probability mass — in the binary format used by saveASR().
void writeNull(unsigned short nullCode,float nullProb,FILE* f)
{
  fwritex((char *)&nullCode,sizeof(short),1,f);
  fwritex((char *)&nullProb,sizeof(float),1,f);
}
+
+
// Save the LM (orders 1..3 only) in the legacy binary ASR format
// documented in the comment block above: a word list followed by nested
// successor tables keyed by 16-bit codes, with big-endian floats written
// via fwritex/ifwrite. Words outside 'subdictfile' (when given) are
// skipped; back-off mass is emitted as a reserved NULL record whose code
// is subdictsz. The 'backoff' parameter is unused. Returns 1 on success,
// exits on unsupported configurations.
// NOTE(review): the unsigned-short capacity guard uses '&&', so it only
// aborts when BOTH dictionaries exceed 0xffff — '||' looks intended.
// NOTE(review): the output is opened with mode "w" (text); on platforms
// that translate line endings "wb" looks intended for binary data.
int mdiadaptlm::saveASR(char *filename,int /* unused parameter: backoff */,char* subdictfile)
{
  int totbg,tottr;

  dictionary* subdict;

  if (subdictfile)
    subdict=new dictionary(subdictfile);
  else
    subdict=dict; // default is subdict=dict

  typedef unsigned short code;

  system("date");

  if (lmsize()>3 || lmsize()<1) {
    cerr << "wrong lmsize\n";
    exit(1);
  }

  if (dict->size()>=0xffff && subdict->size()>=0xffff) {
    cerr << "save bin requires unsigned short codes\n";
    exit(1);
  }

  FILE* f=fopen(filename,"w");

  double fstar,lambda,boff;
  float pr;
  long succ1pos,succ2pos;
  code succ1,succ2,w,h1,h2;
  code stop=0xffff; // terminator of the bigram and trigram sections

  //dictionary
  //#dictsize w1\n ..wN\n NULL\n

  code oovcode=subdict->oovcode();

  //includes at least NULL
  code subdictsz=subdict->size()+1;

  fwritex((char *)&subdictsz,sizeof(code),1,f);

  subdictsz--;
  for (w=0; w<subdictsz; w++)
    fprintf(f,"%s\n",(char *)subdict->decode(w));

  fprintf(f,"____\n");

  //unigram part
  //NULL #succ w1 pr1 ..wN prN

  h1=subdictsz;
  fwritex((char *)&h1,sizeof(code),1,f); //NULL

  succ1=0;
  succ1pos=ftell(f); // placeholder patched later by ifwrite
  fwritex((char *)&succ1,sizeof(code),1,f);

  ngram ng(dict);
  ngram sng(subdict);

  ng.size=sng.size=1;

  scan(ng,INIT,1);
  while(scan(ng,CONT,1)) {
    sng.trans(ng);
    if (sng.containsWord(subdict->OOV(),1))
      continue;

    pr=(float)mdiadaptlm::prob(ng,1);
    if (pr>1e-50) { //do not consider too low probabilities
      succ1++;
      w=*sng.wordp(1);
      fwritex((char *)&w,sizeof(code),1,f);
      fwritex((char *)&pr,sizeof(float),1,f);
    } else {
      cerr << "small prob word " << ng << "\n";
    }
  }

  // update number of unigrams
  ifwrite(succ1pos,&succ1,sizeof(code),1,f);

  cerr << "finito unigrammi " << succ1 << "\n";
  fflush(f);

  if (lmsize()==1) {
    fclose(f);
    return 1;
  }

  // rest of bigrams
  // w1 #succ w1 pr1 .. wN prN

  succ1=0;
  h1=subdictsz;
  totbg=subdictsz;

  ngram hg1(dict,1);

  ng.size=sng.size=2;

  scan(hg1,INIT,1);
  while(scan(hg1,CONT,1)) {

    if (hg1.containsWord(dict->OOV(),1)) continue;

    MY_ASSERT((*hg1.wordp(1))<dict->size());

    *ng.wordp(2)=*hg1.wordp(1);
    *ng.wordp(1)=0;

    sng.trans(ng);
    if (sng.containsWord(dict->OOV(),1)) continue;

    mdiadaptlm::bodiscount(ng,2,fstar,lambda,boff);

    if (lambda < 1.0) {

      h1=*sng.wordp(2);

      fwritex((char *)&h1,sizeof(code),1,f);

      succ1=0;
      succ1pos=ftell(f);
      fwritex((char *)&succ1,sizeof(code),1,f);

      ngram shg=hg1;
      get(shg,1,1);

      succscan(shg,ng,INIT,2);
      while(succscan(shg,ng,CONT,2)) {

        if (*ng.wordp(1)==oovcode) continue;

        sng.trans(ng);
        if (sng.containsWord(dict->OOV(),2)) continue;

        mdiadaptlm::discount(ng,2,fstar,lambda);

        if (fstar>1e-50) {
          w=*sng.wordp(1);
          fwritex((char *)&w,sizeof(code),1,f);
          pr=(float)mdiadaptlm::prob(ng,2);
          //cerr << ng << " prob=" << log(pr) << "\n";

          fwritex((char *)&pr,sizeof(float),1,f);
          succ1++;
        }
      }

      if (succ1) {
        lambda/=boff; //consider backoff
        writeNull(subdictsz,(float)lambda,f);
        succ1++;
        totbg+=succ1;
        ifwrite(succ1pos,&succ1,sizeof(code),1,f);
      } else {
        //go back one word: drop the empty successor table header
        fseek(f,succ1pos-(streampos)sizeof(code),SEEK_SET);
      }
    }
  }

  fwritex((char *)&stop,sizeof(code),1,f);

  cerr << " finito bigrammi! " << subdictsz << "\n";
  fflush(f);

  system("date");

  if (lmsize()<3) {
    fclose(f);
    return 1;
  }

  //TRIGRAM PART

  h1=subdictsz;
  h2=subdictsz;
  tottr=0;
  succ1=0;
  succ2=0;

  ngram hg2(dict,2);

  ng.size=sng.size=3;

  scan(hg1,INIT,1);
  while(scan(hg1,CONT,1)) {

    if ((*hg1.wordp(1)==oovcode)) continue;

    *ng.wordp(3)=*hg1.wordp(1);

    sng.trans(ng);
    if (sng.containsWord(dict->OOV(),1)) continue;

    MY_ASSERT((*sng.wordp(3))<subdictsz);

    h1=*sng.wordp(3);
    fwritex((char *)&h1,sizeof(code),1,f);

    succ1=0;
    succ1pos=ftell(f);
    fwritex((char *)&succ1,sizeof(code),1,f);

    ngram shg1=ng;
    get(shg1,3,1);

    succscan(shg1,hg2,INIT,2);
    while(succscan(shg1,hg2,CONT,2)) {

      if (*hg2.wordp(1)==oovcode) continue;

      *ng.wordp(2)=*hg2.wordp(1);
      *ng.wordp(1)=0;

      sng.trans(ng);
      if (sng.containsWord(dict->OOV(),2)) continue;

      mdiadaptlm::bodiscount(ng,3,fstar,lambda,boff);

      if (lambda < 1.0) {

        h2=*sng.wordp(2);
        fwritex((char *)&h2,sizeof(code),1,f);

        succ2=0;
        succ2pos=ftell(f);
        fwritex((char *)&succ2,sizeof(code),1,f);

        ngram shg2=ng;
        get(shg2,3,2);

        succscan(shg2,ng,INIT,3);
        while(succscan(shg2,ng,CONT,3)) {

          if (*ng.wordp(1)==oovcode) continue;

          sng.trans(ng);
          if (sng.containsWord(dict->OOV(),3)) continue;

          mdiadaptlm::discount(ng,3,fstar,lambda);
          //pr=(float)mdiadaptlm::prob2(ng,3,fstar);

          if (fstar>1e-50) {

            w=*sng.wordp(1);
            fwritex((char *)&w,sizeof(code),1,f);

            pr=(float)mdiadaptlm::prob(ng,3);

            // cerr << ng << " prob=" << log(pr) << "\n";
            fwritex((char *)&pr,sizeof(float),1,f);
            succ2++;
          }
        }

        if (succ2) {
          lambda/=boff;
          writeNull(subdictsz,(float)lambda,f);
          succ2++;
          tottr+=succ2;
          ifwrite(succ2pos,&succ2,sizeof(code),1,f);
          succ1++;
        } else {
          //go back one word
          fseek(f,succ2pos-(long)sizeof(code),SEEK_SET);
        }
      }
    }

    if (succ1)
      ifwrite(succ1pos,&succ1,sizeof(code),1,f);
    else
      fseek(f,succ1pos-(long)sizeof(code),SEEK_SET);
  }

  fwritex((char *)&stop,sizeof(code),1,f);

  fclose(f);

  cerr << "Tot bg: " << totbg << " tg: " << tottr<< "\n";

  system("date");

  return 1;
};
+
+
///// Save in IRST MT format.
// Write the LM to 'filename' as an "nGrAm" table: probabilities (and,
// for non-top orders, lambda/bo back-off terms tagged with a BACKOFF_
// symbol) are encoded as integer pseudo-frequencies, either directly
// scaled by 1e7 or log-quantized with base 'decay' when a finite
// 'resolution' is given. A unigram for the OOV word is added to the
// table if missing. Entries containing OOV words are skipped. Always
// returns 1.
int mdiadaptlm::saveMT(char *filename,int backoff,
                       char* subdictfile,int resolution,double decay)
{

  double logalpha=log(decay); // quantization base
  dictionary* subdict;

  if (subdictfile)
    subdict=new dictionary(subdictfile);
  else
    subdict=dict; // default is subdict=dict

  ngram ng(dict,lmsize());
  ngram sng(subdict,lmsize());

  cerr << "Adding unigram of OOV word if missing\n";

  for (int i=1; i<=maxlevel(); i++)
    *ng.wordp(i)=dict->oovcode();

  if (!get(ng,maxlevel(),1)) {
    cerr << "oov is missing in the ngram-table\n";
    // f(oov) = dictionary size (Witten Bell)
    ng.freq=dict->freq(dict->oovcode());
    cerr << "adding oov unigram " << ng << "\n";
    put(ng);
  }

  cerr << "Eventually adding OOV symbol to subdictionary\n";
  subdict->encode(OOV_);

  system("date");

  mfstream out(filename,ios::out);

  //add special symbols

  subdict->incflag(1);
  int bo_code=subdict->encode(BACKOFF_);
  int du_code=subdict->encode(DUMMY_);
  subdict->incflag(0);

  out << "nGrAm " << lmsize() << " " << 0
      << " " << "LM_ "
      << resolution << " "
      << decay << "\n";

  subdict->save(out);

  //start writing ngrams

  cerr << "write unigram of oov probability\n";
  ng.size=1;
  *ng.wordp(1)=dict->oovcode();
  double pr=(float)mdiadaptlm::prob(ng,1);
  sng.trans(ng);
  sng.size=lmsize();
  for (int s=2; s<=lmsize(); s++) *sng.wordp(s)=du_code; // pad with DUMMY_
  sng.freq=(int)ceil(pr * (double)10000000)-1;
  out << sng << "\n";

  for (int i=1; i<=lmsize(); i++) {
    cerr << "LEVEL " << i << "\n";

    double fstar,lambda,bo,dummy;

    scan(ng,INIT,i);
    while(scan(ng,CONT,i)) {

      sng.trans(ng);

      sng.size=lmsize();
      for (int s=i+1; s<=lmsize(); s++)
        *sng.wordp(s)=du_code;

      if (i>=1 && sng.containsWord(subdict->OOV(),sng.size)) {
        cerr << "skipping : " << sng << "\n";
        continue;
      }

      // skip also eos symbols not at the final
      //if (i>=1 && sng.containsWord(dict->EoS(),sng.size))
      //continue;

      mdiadaptlm::discount(ng,i,fstar,dummy);

      //out << sng << " fstar " << fstar << " lambda " << lambda << "\n";
      //if (i==1 && sng.containsWord(subdict->OOV(),i)){
      // cerr << sng << " fstar " << fstar << "\n";
      //}

      if (fstar>0) {

        double pr=(float)mdiadaptlm::prob(ng,i);

        if (i>1 && resolution<10000000) {
          // log-quantized pseudo-frequency
          sng.freq=resolution-(int)(log(pr)/logalpha)-1;
          sng.freq=(sng.freq>=0?sng.freq:0);
        } else
          sng.freq=(int)ceil(pr * (double)10000000)-1;

        out << sng << "\n";

      }

      if (i<lmsize()) { /// write backoff of higher order!!

        ngram ng2=ng;
        ng2.pushc(0); //extend by one
        mdiadaptlm::bodiscount(ng2,i+1,dummy,lambda,bo);
        MY_ASSERT(!backoff || (lambda ==1 || bo<1 ));

        sng.pushc(bo_code); // tag the record as a back-off entry
        sng.size=lmsize();

        if (lambda<1) {
          if (resolution<10000000) {
            sng.freq=resolution-(int)((log(lambda) - log(bo))/logalpha)-1;
            sng.freq=(sng.freq>=0?sng.freq:0);
          } else
            sng.freq=(int)ceil(lambda/bo * (double)10000000)-1;

          out << sng << "\n";
        }
      }
    }
    cerr << "LEVEL " << i << "DONE \n";
  }
  return 1;
};
+
+ ///// Save in binary format forbackoff N-gram models
+
+ int mdiadaptlm::saveBIN_per_word(char *filename,int backoff,char* subdictfile,int mmap)
+ {
+ VERBOSE(2,"mdiadaptlm::saveBIN_per_word START\n");
+ system("date");
+
+ //subdict
+ dictionary* subdict;
+
+ //accumulated unigram oov prob
+ //CHECK why this is not used (differently from what happens in the other save functions
+ // double oovprob=0;
+
+
+ if (subdictfile) subdict=new dictionary(subdictfile);
+ else subdict=dict; // default is subdict=dict
+
+ if (mmap) {
+ VERBOSE(2,"savebin with memory map: " << filename << "\n");
+ } else {
+ VERBOSE(2,"savebin: " << filename << "\n");
+ }
+
+ int maxlev=lmsize();
+ streampos pos[LMTMAXLEV+1];
+ char buff[100];
+ int isQuant=0; //savebin for quantized LM is not yet implemented
+
+ //temporary filename to save the LM related to a single term
+ char tmpfilename[BUFSIZ];
+
+ //create temporary output file stream to store single levels for all terms
+ MY_ASSERT(strlen(filename)<1000);
+ char tfilename[LMTMAXLEV+1][1000];
+ mfstream *tout[LMTMAXLEV+1];
+
+ tout[0]=NULL;
+ for (int i=1; i<=maxlev; i++) {
+ sprintf(tfilename[i],"%s-%dgrams",filename,i);
+ tout[i]=new mfstream(tfilename[i],ios::out);
+ }
+
+ // print header in the main output file
+ mfstream out(filename,ios::out);
+ out << "blmt " << maxlev;
+
+ for (int i=1; i<=maxlev; i++) { //reserve space for ngram statistics (which are not yet avalable)
+ pos[i]=out.tellp();
+ sprintf(buff," %10d",0);
+ out << buff;
+ }
+ out << "\n";
+ subdict->save(out);
+ out.flush();
+
+ ngram ng(dict,maxlev);
+ ngram oldng(dict,maxlev);
+ ngram locng(dict,maxlev);
+
+ ngram sng(subdict,maxlev);
+
+ double fstar,lambda,bo,dummy,dummy2,pr,ibow;
+
+ double oovprob=0.0; //accumulated unigram oov pro
+ bool _OOV_unigram=false; //flag to check whether an OOV word is present or not
+
+ //n-gram counters
+ table_entry_pos_t num[LMTMAXLEV+1];
+ for (int i=1; i<=maxlev; i++) num[i]=0;
+
+ lmtable* lmt = new lmtable();
+
+ lmt->configure(maxlev,isQuant);
+ lmt->setDict(subdict);
+ lmt->expand_level(1,dict->size(),filename,mmap);
+
+ //main loop
+ for (int w=0; w<dict->size(); w++) {
+ int i=1; //set the initial value of level
+ sprintf(tmpfilename,"%s_tmp_%d",filename,w);
+
+ if (!w % 10000) cerr << ".";
+
+ //1-gram
+ ngram ung(dict,1);
+ *ung.wordp(1)=w;
+ sng.trans(ung);
+
+
+ // frequency pruning is not applied to unigrams
+
+ /*
+ //exclude words not occurring in the subdictionary
+ if (sng.containsWord(subdict->OOV(),1) && !ung.containsWord(dict->OOV(),1)) continue;
+ */
+
+ pr=mdiadaptlm::prob(ung,1);
+
+ if (sng.containsWord(subdict->OOV(),1) || ung.containsWord(dict->OOV(),1)) {
+ _OOV_unigram=true;
+ oovprob+=pr; //accumulate oov probability
+ continue;
+ }
+ pr=(pr?log10(pr):-99);
+
+ if (i<maxlev) { //compute back-off
+ ung.pushc(0); //extend by one
+ VERBOSE(3,"mdiadaptlm::saveBIN_per_word(char *filename,int backoff,char* subdictfile ) computing backoff for ung:|" << ung << "| size:" << i+1 << std::endl);
+ mdiadaptlm::bodiscount(ung,i+1,dummy,lambda,bo);
+ VERBOSE(3,"mdiadaptlm::saveBIN_per_word(char *filename,int backoff,char* subdictfile ) getting backoff for ung:|" << ung << "| lambda:" << lambda << " bo:" << bo << std::endl);
+ ung.shift();//shrink by one
+
+ if (fstar<UPPER_SINGLE_PRECISION_OF_0 && lambda>LOWER_SINGLE_PRECISION_OF_1){ //ngram must be skipped
+ ibow = DONT_PRINT;
+ }else{
+ if (backoff){
+ ibow=log10(lambda) - log10(bo);
+ }else{
+ MY_ASSERT((lambda<UPPER_SINGLE_PRECISION_OF_1 && lambda>LOWER_SINGLE_PRECISION_OF_1) || bo<UPPER_SINGLE_PRECISION_OF_1 );
+ if (lambda<LOWER_SINGLE_PRECISION_OF_1){
+ ibow = log10(lambda);
+ }else { //force to be 0.0
+ ibow = 0.0;
+ }
+ }
+ }
+ }
+ else {
+ ibow=0.0; //default value for backoff weight at the lowest level
+ }
+
+ if (ibow != DONT_PRINT){
+ lmt->addwithoffset(ung,(float)pr,(float)ibow);
+ }
+ num[i]++;
+
+ //manage n-grams
+ if (get(ung,1,1)) {
+
+ //create n-gram with history w
+ *ng.wordp(lmsize())=w;
+
+ //create sentinel n-gram
+ for (int i=1; i<=maxlev; i++) *oldng.wordp(i)=-1;
+
+ //create the table for all levels but the level 1, with the maximum number of possible entries
+ for (int i=2; i<=maxlev; i++)
+ lmt->expand_level(i,entries(i),tmpfilename,mmap);
+
+ scan(ung.link,ung.info,1,ng,INIT,lmsize());
+ while(scan(ung.link,ung.info,1,ng,CONT,lmsize())) {
+ sng.trans(ng); // convert to subdictionary
+// locng=ng; // make a local copy
+
+ //find first internal level that changed
+ int f=maxlev-1; //unigrams have been already covered
+ while (f>1 && (*oldng.wordp(f)==*ng.wordp(f))){ f--; }
+
+ for (int l=maxlev-(f-1); l<=lmsize(); l++){
+
+ locng=ng; // make a local copy
+ if (l<lmsize()) locng.shift(maxlev-l); //reduce the ngram, which has size level
+
+ // frequency pruning: skip n-grams with low frequency
+ if (prune_ngram(l,sng.freq)) continue;
+
+ // skip n-grams containing OOV
+ if (sng.containsWord(subdict->OOV(),l)) continue;
+
+ // skip also n-grams containing eos symbols not at the final
+ if (sng.containsWord(dict->EoS(),l-1)) continue;
+
+ VERBOSE(3,"mdiadaptlm::saveBIN_per_word(char *filename,int backoff,char* subdictfile ) computing prob for locng:|" << locng << "| size:" << l << std::endl);
+ pr=mdiadaptlm::prob(locng,l,fstar,dummy,dummy2);
+ VERBOSE(3,"mdiadaptlm::saveBIN_per_word(char *filename,int backoff,char* subdictfile ) getting prob locng:|" << locng << "| size:" << l << " fstar:" << fstar << " pr:" << pr << std::endl);
+
+ //PATCH by Nicola (16-04-2008)
+
+ if (!(pr<=1.0 && pr > 1e-10)) {
+ cerr << ng << " " << pr << "\n";
+ MY_ASSERT(pr<=1.0);
+ cerr << "prob modified to 1e-10\n";
+ pr=1e-10;
+ }
+
+ if (l<lmsize()) {
+
+ locng.pushc(0); //extend by one
+
+ VERBOSE(3,"mdiadaptlm::saveBIN_per_word(char *filename,int backoff,char* subdictfile ) computing backoff for locng:|" << locng << "| size:" << l+1 << std::endl);
+ mdiadaptlm::bodiscount(locng,l+1,dummy,lambda,bo);
+ VERBOSE(3,"mdiadaptlm::saveBIN_per_word(char *filename,int backoff,char* subdictfile ) getting backoff locng:|" << locng << "| lambda:" << lambda << " bo:" << bo << std::endl);
+
+ locng.shift();
+
+ if (fstar<UPPER_SINGLE_PRECISION_OF_0 && lambda>LOWER_SINGLE_PRECISION_OF_1){ //ngram must be skipped
+ ibow = DONT_PRINT;
+ }else{
+ if (backoff){
+ ibow = (float) (log10(lambda) - log10(bo));
+ }else{
+ MY_ASSERT((lambda<UPPER_SINGLE_PRECISION_OF_1 && lambda>LOWER_SINGLE_PRECISION_OF_1) || bo<UPPER_SINGLE_PRECISION_OF_1 );
+ if (lambda<LOWER_SINGLE_PRECISION_OF_1){
+ ibow = log10(lambda);
+ }else{ //no output if log10(lambda)==0
+ ibow = 0.0;
+ }
+ }
+ }
+ } else { //i==maxlev
+ ibow = 0.0;
+ }
+
+ if (fstar>=UPPER_SINGLE_PRECISION_OF_0 || ibow!=DONT_PRINT ) {
+ if (lmt->addwithoffset(locng,(float)log10(pr),(float)ibow)){
+ num[l]++;
+ }else{
+ continue;
+ }
+ } else{
+ continue; //skip n-grams with too small fstar
+ }
+ }
+ oldng=ng;
+ }
+ }
+ else{
+ //create empty tables for all levels but the level 1, to keep consistency with the rest of the code
+ for (int i=2; i<=maxlev; i++)
+ lmt->expand_level(i,0,tmpfilename,mmap);
+ }
+
+
+ //level 1 is not modified until everything is done
+ //because it has to contain the full dictionary
+ //which provides the direct access to the second level
+ for (int i=2; i<=lmsize(); i++){
+
+ if (i>2) {
+ lmt->checkbounds(i-1);
+ lmt->appendbin_level(i-1, *tout[i-1], mmap);
+ }
+
+ // now we can resize table at level i
+ lmt->resize_level(i, tmpfilename, mmap);
+ }
+
+ // now we can save table at level maxlev, if not equal to 1
+ if (maxlev>1){
+ lmt->appendbin_level(maxlev, *tout[maxlev], mmap);
+ }
+
+ //delete levels from 2 to lmsize();
+ for (int i=2; i<=maxlev; i++) lmt->delete_level(i, tmpfilename, mmap);
+
+ //update table offsets
+ for (int i=2; i<=maxlev; i++) lmt->update_offset(i,num[i]);
+ }
+
+ if (_OOV_unigram){
+ ngram ung(dict,1);
+ *ung.wordp(1)=dict->oovcode();
+ pr=oovprob;
+ ibow=0.0;
+ lmt->addwithoffset(ung,(float)pr,(float)ibow);
+ num[1]++;
+ }
+
+ //close levels from 2 to lmsize()
+ for (int i=2; i<=maxlev; i++) tout[i]->close();
+
+ //now we can save level 1, which contains all unigrams
+ //cerr << "saving level 1" << "...\n";
+ lmt->savebin_level(1, filename, mmap);
+
+ //update headers
+ for (int i=1; i<=maxlev; i++) {
+ sprintf(buff," %10d",num[i]);
+ out.seekp(pos[i]);
+ out << buff;
+ }
+
+ out.close();
+
+ //concatenate files for each single level into one file
+ //single level files should have a name derived from "filename"
+ lmt->compact_all_levels(filename);
+
+ cerr << "\n";
+ system("date");
+
+ VERBOSE(2,"mdiadaptlm::saveBIN_per_word END\n");
+ return 1;
+ };
+
+ ///// Save in binary format for backoff N-gram models.
+ ///// Writes a "blmt" header, the (sub)dictionary, then one binary table per
+ ///// n-gram level, saving each level to disk as soon as the next one is built
+ ///// (low peak memory). Returns 1 on completion.
+ ///// filename: output path; backoff: if nonzero store log10(lambda)-log10(bo)
+ ///// instead of log10(lambda); subdictfile: optional subdictionary restricting
+ ///// the saved vocabulary; mmap: if nonzero build tables via memory mapping.
+ int mdiadaptlm::saveBIN_per_level(char *filename,int backoff,char* subdictfile,int mmap)
+ {
+ VERBOSE(2,"mdiadaptlm::saveBIN_per_level START\n");
+ system("date");
+
+ //subdict
+ dictionary* subdict;
+
+ if (subdictfile) subdict=new dictionary(subdictfile);
+ else subdict=dict; // default is subdict=dict
+
+ if (mmap) {
+ VERBOSE(2,"savebin with memory map: " << filename << "\n");
+ } else {
+ VERBOSE(2,"savebin: " << filename << "\n");
+ }
+
+ int maxlev=lmsize();
+ streampos pos[LMTMAXLEV+1];
+ char buff[100];
+ int isQuant=0; //savebin for quantized LM is not yet implemented
+
+ // print header
+ fstream out(filename,ios::out);
+ out << "blmt " << maxlev;
+
+ for (int i=1; i<=maxlev; i++) { //reserve space for ngram statistics (which are not yet available)
+ pos[i]=out.tellp();
+ sprintf(buff," %10d",0);
+ out << buff;
+ }
+ out << "\n";
+ lmtable* lmt = new lmtable();
+
+ lmt->configure(maxlev,isQuant);
+
+ lmt->setDict(subdict);
+ subdict->save(out);
+ out.flush();
+
+
+ //start adding n-grams to lmtable
+
+ for (int i=1; i<=maxlev; i++) {
+ cerr << "saving level " << i << "...\n";
+ table_entry_pos_t numberofentries;
+ if (i==1) { //unigram
+ numberofentries = (table_entry_pos_t) subdict->size();
+ } else {
+ numberofentries = (table_entry_pos_t) entries(i);
+ }
+ system("date");
+ lmt->expand_level(i,numberofentries,filename,mmap);
+
+ double fstar,lambda,bo,dummy,dummy2,pr,ibow;
+
+ ngram ng(dict,1);
+ ngram ng2(dict);
+ ngram sng(subdict,1);
+
+ if (i==1) { //unigram case
+
+ double oovprob=0.0; //accumulated unigram oov prob
+ bool _OOV_unigram=false; //flag to check whether an OOV word is present or not
+
+ //scan the dictionary
+ for (int w=0; w<dict->size(); w++) {
+ *ng.wordp(1)=w;
+
+ sng.trans(ng);
+
+ // frequency pruning is not applied to unigrams
+
+ pr=mdiadaptlm::prob(ng,i);
+
+ if (sng.containsWord(subdict->OOV(),i) || ng.containsWord(dict->OOV(),i)) {
+ _OOV_unigram=true;
+ oovprob+=pr; //accumulate oov probability
+ continue;
+ }
+
+ /*
+ if (sng.containsWord(subdict->OOV(),i) && !ng.containsWord(dict->OOV(),i)) {
+ oovprob+=pr; //accumulate oov probability
+ continue;
+ }
+
+ if (ng.containsWord(dict->OOV(),i)) pr+=oovprob;
+ */
+
+ //cerr << ng << " freq " << dict->freq(w) << " - Pr " << pr << "\n";
+ // NOTE(review): from here on pr holds log10(pr) (or -99 for pr==0).
+ pr=(pr?log10(pr):-99);
+
+ /*
+ if (w==dict->oovcode()){
+ //CHECK whether we can avoid this reassignment because dict should be lmt->getDict()
+ *ng.wordp(1)=lmt->getDict()->oovcode();
+ ibow=0.0;
+ }
+ else {
+ // } //do nothing
+ */
+ if (i<maxlev) {
+ ngram ng2=ng;
+ ng2.pushc(0); //extend by one
+
+ //cerr << ng2 << "\n";
+
+ VERBOSE(3,"mdiadaptlm::saveBIN_per_level(char *filename,int backoff,char* subdictfile ) computing backoff for ng2:|" << ng2 << "| size:" << i+1 << std::endl);
+ mdiadaptlm::bodiscount(ng2,i+1,dummy,lambda,bo);
+ VERBOSE(3,"mdiadaptlm::saveBIN_per_level(char *filename,int backoff,char* subdictfile ) getting backoff for ng2:|" << ng2 << "| lambda:" << lambda << " bo:" << bo << std::endl);
+
+ // NOTE(review): fstar is never assigned in this unigram branch (only the
+ // 5-argument prob() sets it); this test reads an uninitialized double — confirm.
+ if (fstar<UPPER_SINGLE_PRECISION_OF_0 && lambda>LOWER_SINGLE_PRECISION_OF_1){ //ngram must be skipped
+ ibow = DONT_PRINT;
+ }else{
+ if (backoff){
+ ibow = log10(lambda) - log10(bo);
+ }
+ else{
+ MY_ASSERT((lambda<UPPER_SINGLE_PRECISION_OF_1 && lambda>LOWER_SINGLE_PRECISION_OF_1) || bo<UPPER_SINGLE_PRECISION_OF_1 );
+ if (lambda<LOWER_SINGLE_PRECISION_OF_1){
+ ibow = log10(lambda);
+ }else { //force to be 0.0
+ ibow = 0.0;
+ }
+ }
+ }
+ }else { //i==maxlev
+ ibow=0.0; //default value for backoff weight at the highest level
+ }
+ // NOTE(review): the VERBOSE message below names saveARPA_per_level but we are in
+ // saveBIN_per_level — stale copy/paste in the log string (left untouched).
+ VERBOSE(3,"mdiadaptlm::saveARPA_per_level(char *filename,int backoff,char* subdictfile ) writing w:|" << (char *)dict->decode(w) << "| pr:" << pr << " ibow:" << ibow << std::endl);
+ if (ibow != DONT_PRINT ) {
+ // NOTE(review): pr was already converted to log10 above, so log10(pr) here
+ // applies the log twice (and is NaN for pr<1 after conversion); the OOV case
+ // below passes pr directly — confirm which form is intended.
+ lmt->add(ng,(float)log10(pr),(float)ibow);
+ }
+ }
+ //add unigram with OOV and its accumulated oov probability
+ if (_OOV_unigram){
+ *ng.wordp(1)=lmt->getDict()->oovcode();
+ ibow=0.0;
+ pr=oovprob;
+ pr=(pr?log10(pr):-99);
+ lmt->add(ng,(float)pr,(float)ibow);
+ }
+ }
+ else { //i>1 , bigrams, trigrams, fourgrams...
+ *ng.wordp(1)=0;
+ get(ng,1,1); //this
+ scan(ng,INIT,i);
+ while(scan(ng,CONT,i)) {
+ sng.trans(ng);
+
+ // frequency pruning: skip n-grams with low frequency
+ if (prune_ngram(i,sng.freq)) continue;
+
+ // skip n-grams containing OOV
+ if (sng.containsWord(subdict->OOV(),i)) continue;
+
+ // skip also n-grams containing eos symbols not in final position
+ if (sng.containsWord(dict->EoS(),i-1)) continue;
+
+ // mdiadaptlm::discount(ng,i,fstar,dummy);
+ // pr=mdiadaptlm::prob(ng,i);
+ pr=mdiadaptlm::prob(ng,i,fstar,dummy,dummy2);
+
+ // clamp out-of-range probabilities (PATCH by Nicola, 16-04-2008 in the ARPA twin)
+ if (!(pr<=1.0 && pr > 1e-10)) {
+ cerr << ng << " " << pr << "\n";
+ MY_ASSERT(pr<=1.0);
+ cerr << "prob modified to 1e-10\n";
+ pr=1e-10;
+ }
+
+ if (i<maxlev) {
+ ng2=ng;
+ ng2.pushc(0); //extend by one
+ VERBOSE(3,"mdiadaptlm::saveBIN_per_level(char *filename,int backoff,char* subdictfile ) computing backoff for ng2:|" << ng2 << "| size:" << i+1 << std::endl);
+ mdiadaptlm::bodiscount(ng2,i+1,dummy,lambda,bo);
+ VERBOSE(3,"mdiadaptlm::saveBIN_per_level(char *filename,int backoff,char* subdictfile ) getting backoff for ng2:|" << ng2 << "| lambda:" << lambda << " bo:" << bo << std::endl);
+
+ if (fstar<UPPER_SINGLE_PRECISION_OF_0 && lambda>LOWER_SINGLE_PRECISION_OF_1){ //ngram must be skipped
+ ibow=DONT_PRINT;
+ }else{
+ if (backoff){
+ ibow=log10(lambda) - log10(bo);
+ }else{
+ MY_ASSERT((lambda<UPPER_SINGLE_PRECISION_OF_1 && lambda>LOWER_SINGLE_PRECISION_OF_1) || bo<UPPER_SINGLE_PRECISION_OF_1 );
+ if (lambda<LOWER_SINGLE_PRECISION_OF_1){
+ ibow=log10(lambda);
+ }else{ //force ibow to log10(lambda)==0.0
+ ibow=0.0;
+ }
+ }
+ }
+ } else { //i==maxlev
+ ibow=0.0; //value for backoff weight at the highest level
+ }
+ VERBOSE(3,"mdiadaptlm::saveBIN_per_level(char *filename,int backoff,char* subdictfile ) writing ng:|" << ng << "| pr:" << pr << " ibow:" << ibow << std::endl);
+ if (ibow != DONT_PRINT ) {
+ lmt->add(ng,(float)log10(pr),(float)ibow);
+ }
+ }
+ }
+
+ // now we can fix table at level i-1
+ // now we can save table at level i-1
+ // now we can remove table at level i-1
+ if (maxlev>1 && i>1) {
+ lmt->checkbounds(i-1);
+ lmt->savebin_level(i-1, filename, mmap);
+ }
+
+ // now we can resize table at level i
+ lmt->resize_level(i, filename, mmap);
+
+ }
+ // now we can save table at level maxlev
+ lmt->savebin_level(maxlev, filename, mmap);
+
+ //update headers with the real per-level entry counts
+ for (int i=1; i<=maxlev; i++) {
+ sprintf(buff," %10d",lmt->getCurrentSize(i));
+ out.seekp(pos[i]);
+ out << buff;
+ }
+ out.close();
+
+ //concatenate files for each single level into one file
+ //single level files should have a name derived from "filename"
+ lmt->compact_all_levels(filename);
+
+ VERBOSE(2,"mdiadaptlm::saveBIN_per_level END\n");
+ return 1;
+ }
+
+
+ ///// Save in format for ARPA backoff N-gram models.
+ ///// Word-by-word strategy: iterates over the vocabulary, emitting for each word w
+ ///// its unigram entry and then every n-gram whose history starts with w, writing
+ ///// each order to a temporary per-level file; the files are concatenated into the
+ ///// final ARPA file at the end. Returns 1 on completion.
+ int mdiadaptlm::saveARPA_per_word(char *filename,int backoff,char* subdictfile )
+ {
+ VERBOSE(2,"mdiadaptlm::saveARPA_per_word START\n");
+ system("date");
+
+ //subdict
+ dictionary* subdict;
+
+
+ if (subdictfile) subdict=new dictionary(subdictfile);
+ else subdict=dict; // default is subdict=dict
+
+ //main output file
+ mfstream out(filename,ios::out);
+
+ int maxlev=lmsize();
+ //create temporary output file stream (one per level, named "<filename>.<level>")
+ MY_ASSERT(strlen(filename)<1000);
+ char tfilename[LMTMAXLEV+1][1000];
+ mfstream *tout[LMTMAXLEV+1];
+
+ tout[0]=NULL;
+ for (int i=1; i<=maxlev; i++) {
+ sprintf(tfilename[i],"%s.%d",filename,i);
+ tout[i]=new mfstream(tfilename[i],ios::out);
+ *tout[i] << "\n\\" << i << "-grams:\n";
+ }
+
+ ngram ng(dict,lmsize());
+ ngram oldng(dict,lmsize());
+ ngram locng(dict,lmsize());
+
+ ngram sng(subdict,lmsize());
+
+ double fstar,lambda,bo,dummy,dummy2,pr,outLambda;
+
+ double oovprob=0.0; //accumulated unigram oov prob
+ bool _OOV_unigram=false; //flag to check whether an OOV word is present or not
+
+ //n-gram counters
+ table_entry_pos_t num[LMTMAXLEV+1];
+ for (int i=1; i<=maxlev; i++) num[i]=0;
+
+
+ //main loop over the vocabulary
+ for (int w=0; w<dict->size(); w++) {
+ int i=1; //set the initial value of level
+ // NOTE(review): "!w % 10000" parses as "(!w) % 10000", so the progress dot is
+ // printed only when w==0; "!(w % 10000)" was almost certainly intended.
+ if (!w % 10000) cerr << ".";
+
+ //1-gram
+ ngram ung(dict,1);
+ *ung.wordp(1)=w;
+ sng.trans(ung);
+
+ // frequency pruning is not applied to unigrams
+
+ /*
+ //exclude words not occurring in the subdictionary
+ if (sng.containsWord(subdict->OOV(),1) && !ung.containsWord(dict->OOV(),1)) continue;
+ */
+
+ pr=mdiadaptlm::prob(ung,1);
+ // NOTE(review): from here on pr holds log10(pr) (or -99 for pr==0).
+ pr=(pr?log10(pr):-99);
+
+ //////CHECK
+ if (sng.containsWord(subdict->OOV(),1) || ung.containsWord(dict->OOV(),1)) {
+ // NOTE(review): this accumulates log-probabilities, and the OOV output below
+ // applies log10 to the sum again — looks like a double log; confirm.
+ _OOV_unigram=true;
+ oovprob+=pr; //accumulate oov probability
+ continue;
+ }
+
+ if (i<maxlev) { //print back-off
+ ung.pushc(0); //extend by one
+ VERBOSE(3,"mdiadaptlm::saveARPA_per_word(char *filename,int backoff,char* subdictfile ) computing backoff for ung:|" << ung << "| size:" << i+1 << std::endl);
+ mdiadaptlm::bodiscount(ung,i+1,dummy,lambda,bo);
+ VERBOSE(3,"mdiadaptlm::saveARPA_per_word(char *filename,int backoff,char* subdictfile ) getting backoff for ung:|" << ung << "| lambda:" << lambda << " bo:" << bo << std::endl);
+
+ ung.shift();//shrink by one
+ // NOTE(review): fstar has not been assigned on this path (only the 5-argument
+ // prob() sets it); this comparison reads an uninitialized double — confirm.
+ if (fstar<UPPER_SINGLE_PRECISION_OF_0 && lambda>LOWER_SINGLE_PRECISION_OF_1){ //ngram must be skipped
+ outLambda = DONT_PRINT;
+ }else{
+ if (backoff){
+ outLambda = (float) (log10(lambda) - log10(bo));
+ }
+ else{
+ MY_ASSERT((lambda<UPPER_SINGLE_PRECISION_OF_1 && lambda>LOWER_SINGLE_PRECISION_OF_1) || bo<UPPER_SINGLE_PRECISION_OF_1 );
+ if (lambda<LOWER_SINGLE_PRECISION_OF_1){
+ outLambda = (float) log10(lambda);
+ }
+ else {
+ outLambda = DONT_PRINT;
+ }
+ }
+ }
+ }else { //i==maxlev
+ outLambda = DONT_PRINT;
+ }
+
+ //cerr << ng << " freq " << dict->freq(w) << " - Pr " << pr << "\n";
+ // NOTE(review): pr is already log10-converted, so log10(pr) here is a second
+ // application of the log (NaN for negative pr) — confirm intended behavior.
+ *tout[i] << (float) (pr?log10(pr):-99);
+ *tout[i] << "\t" << (char *)dict->decode(w);
+ if (outLambda != DONT_PRINT){
+ *tout[i] << "\t" << outLambda;
+ }
+ *tout[i] << "\n";
+ num[i]++;
+
+ //manage n-grams
+ if (get(ung,1,1)) {
+
+ //create n-gram with history w
+ *ng.wordp(maxlev)=w;
+
+ //create sentinel n-gram
+ for (i=1; i<=maxlev; i++) *oldng.wordp(i)=-1;
+
+ scan(ung.link,ung.info,1,ng,INIT,maxlev);
+ while(scan(ung.link,ung.info,1,ng,CONT,maxlev)) {
+ //cerr << ng << "\n";
+ sng.trans(ng); // convert to subdictionary
+ locng=ng; // make a local copy
+
+ //find first internal level that changed
+ int f=maxlev-1; //unigrams have been already covered
+ while (f>1 && (*oldng.wordp(f)==*ng.wordp(f))){ f--; }
+
+ for (int l=maxlev; l>maxlev-f;l--){
+
+ if (l<maxlev) locng.shift(); //ngram has size level
+
+ // frequency pruning: skip n-grams with low frequency
+ if (prune_ngram(l,sng.freq)) continue;
+
+ // skip n-grams containing OOV
+ if (sng.containsWord(subdict->OOV(),l)) continue;
+
+ // skip also n-grams containing eos symbols not in final position
+ if (sng.containsWord(dict->EoS(),l-1)) continue;
+ VERBOSE(3,"mdiadaptlm::saveARPA_per_word(char *filename,int backoff,char* subdictfile ) computing prob for locng:|" << locng << "| size:" << i << std::endl);
+ pr=mdiadaptlm::prob(locng,l,fstar,dummy,dummy2);
+ VERBOSE(3,"mdiadaptlm::saveARPA_per_word(char *filename,int backoff,char* subdictfile ) getting prob locng:|" << locng << "| size:" << i << " fstar:" << fstar << " pr:" << pr << std::endl);
+
+ //PATCH by Nicola (16-04-2008)
+
+ if (!(pr<=1.0 && pr > 1e-10)) {
+ cerr << ng << " " << pr << "\n";
+ MY_ASSERT(pr<=1.0);
+ cerr << "prob modified to 1e-10\n";
+ pr=1e-10;
+ }
+
+ if (l<maxlev) {
+
+ locng.pushc(0); //extend by one
+ VERBOSE(3,"mdiadaptlm::saveARPA_per_word(char *filename,int backoff,char* subdictfile ) computing backoff for locng:|" << locng << "| size:" << l+1 << std::endl);
+ mdiadaptlm::bodiscount(locng,l+1,dummy,lambda,bo);
+ VERBOSE(3,"mdiadaptlm::saveARPA_per_word(char *filename,int backoff,char* subdictfile ) getting backoff locng:|" << locng << "| lambda:" << lambda << " bo:" << bo << std::endl);
+
+ locng.shift();
+ if (fstar<UPPER_SINGLE_PRECISION_OF_0 && lambda>LOWER_SINGLE_PRECISION_OF_1){ //ngram must be skipped
+ outLambda = DONT_PRINT;
+ }else{
+ if (backoff){
+ outLambda = (float) (log10(lambda) - log10(bo));
+ }else{
+ MY_ASSERT((lambda<UPPER_SINGLE_PRECISION_OF_1 && lambda>LOWER_SINGLE_PRECISION_OF_1) || bo<UPPER_SINGLE_PRECISION_OF_1 );
+ if (lambda<LOWER_SINGLE_PRECISION_OF_1){
+ outLambda = (float) log10(lambda);
+ }else{ //no output if log10(lambda)==0
+ outLambda = DONT_PRINT;
+ }
+ }
+ }
+ } else { //i==maxlev
+ outLambda = DONT_PRINT;
+ }
+
+ if (fstar>=UPPER_SINGLE_PRECISION_OF_0 || outLambda!=DONT_PRINT ) {
+ *tout[l] << (float) log10(pr);
+ // NOTE(review): at this point i==maxlev+1 (left over from the sentinel loop
+ // above), so wordp(i) indexes past the n-gram; l was probably intended here
+ // and in the loop below — confirm against the per_level twin.
+ *tout[l] << "\t" << (char *)dict->decode(*ng.wordp(i));
+ for (int j=i-1; j>0; j--)
+ *tout[l] << " " << (char *)dict->decode(*ng.wordp(j));
+ if (outLambda != DONT_PRINT){
+ *tout[l] << "\t" << outLambda;
+ }
+ *tout[l] << "\n";
+ num[l]++;
+ } else{
+ continue; //skip n-grams with too small fstar
+ }
+ }
+ oldng=ng;
+ }
+ }
+ }
+ //emit the <unk> unigram carrying the accumulated OOV mass
+ if (_OOV_unigram){
+ pr=oovprob;
+ num[1]++;
+ out << (float) (pr?log10(pr):-99);
+ out << "\t" << "<unk>\n";
+ }
+
+ //print header
+ out << "\n\\data\\" << "\n";
+ char buff[100];
+ for (int i=1; i<=maxlev; i++) {
+ sprintf(buff,"ngram %2d=%10d\n",i,num[i]);
+ out << buff;
+ }
+ out << "\n";
+
+ //append and remove temporary files
+ for (int i=1; i<=maxlev; i++) {
+ delete tout[i];
+ tout[i]=new mfstream(tfilename[i],ios::in);
+ out << tout[i]->rdbuf();
+ delete tout[i];
+ removefile(tfilename[i]);
+ }
+
+ out << "\\end\\" << "\n";
+
+ cerr << "\n";
+ system("date");
+
+ VERBOSE(2,"mdiadaptlm::saveARPA_per_word END\n");
+ return 1;
+ };
+
+ ///// Save in format for ARPA backoff N-gram models.
+ ///// Level-by-level strategy: writes the \data\ header (with placeholder counts),
+ ///// then each n-gram section in order, and finally seeks back to patch the real
+ ///// counts into the header. Returns 1 on completion.
+ int mdiadaptlm::saveARPA_per_level(char *filename,int backoff,char* subdictfile )
+ {
+ VERBOSE(2,"mdiadaptlm::saveARPA_per_level START\n");
+ system("date");
+
+ //subdict
+ dictionary* subdict;
+
+ if (subdictfile) {
+ subdict=new dictionary(subdictfile);
+ } else
+ subdict=dict; // default is subdict=dict
+
+ fstream out(filename,ios::out);
+ // out.precision(15);
+
+ int maxlev = lmsize();
+ streampos pos[LMTMAXLEV+1];
+ table_entry_pos_t num[LMTMAXLEV+1];
+ char buff[100];
+
+ //print header, remembering where each count field starts so it can be patched later
+ out << "\n\\data\\" << "\n";
+
+ for (int i=1; i<=maxlev; i++) {
+ num[i]=0;
+ pos[i]=out.tellp();
+ sprintf(buff,"ngram %2d=%10d\n",i,num[i]);
+ out << buff;
+ }
+
+ out << "\n";
+
+ //start writing n-grams
+
+ for (int i=1; i<=maxlev; i++) {
+ cerr << "saving level " << i << "...\n";
+
+
+ out << "\n\\" << i << "-grams:\n";
+
+ double fstar,lambda,bo,dummy,dummy2,pr,outLambda;
+
+ ngram ng(dict,1);
+ ngram ng2(dict);
+ ngram sng(subdict,1);
+
+ if (i==1) { //unigram case
+
+ double oovprob=0.0; //accumulated unigram oov prob
+ bool _OOV_unigram=false; //flag to check whether an OOV word is present or not
+
+ //scan the dictionary
+
+ for (int w=0; w<dict->size(); w++) {
+ *ng.wordp(1)=w;
+
+ sng.trans(ng);
+
+ // frequency pruning is not applied to unigrams
+ VERBOSE(3,"mdiadaptlm::saveARPA_per_level(char *filename,int backoff,char* subdictfile ) computing prob for ng:|" << ng << "| size:" << i << std::endl);
+ pr=mdiadaptlm::prob(ng,i);
+ VERBOSE(3,"mdiadaptlm::saveARPA_per_level(char *filename,int backoff,char* subdictfile ) getting prob for ng:|" << ng << "| pr:" << pr << std::endl);
+
+ if (sng.containsWord(subdict->OOV(),i) || ng.containsWord(dict->OOV(),i)) {
+ _OOV_unigram=true;
+ oovprob+=pr; //accumulate oov probability
+ continue;
+ }
+
+ /*
+ if (sng.containsWord(subdict->OOV(),i) && !ng.containsWord(dict->OOV(),i)) {
+ oovprob+=pr; //accumulate oov probability
+ continue;
+ }
+
+ if (ng.containsWord(dict->OOV(),i)) pr+=oovprob;
+ */
+
+ if (i<maxlev) {
+ ngram ng2=ng;
+ ng2.pushc(0); //extend by one
+
+ VERBOSE(3,"mdiadaptlm::saveARPA_per_level(char *filename,int backoff,char* subdictfile ) computing backoff for ng2:|" << ng2 << "| size:" << i+1 << std::endl);
+ mdiadaptlm::bodiscount(ng2,i+1,dummy,lambda,bo);
+ VERBOSE(3,"mdiadaptlm::saveARPA_per_level(char *filename,int backoff,char* subdictfile ) getting backoff for ng2:|" << ng2 << "| lambda:" << lambda << " bo:" << bo << std::endl);
+ // NOTE(review): fstar is never assigned in this unigram branch (only the
+ // 5-argument prob() sets it); this reads an uninitialized double — confirm.
+ if (fstar<UPPER_SINGLE_PRECISION_OF_0 && lambda>LOWER_SINGLE_PRECISION_OF_1){ //ngram must be skipped
+ outLambda = DONT_PRINT;
+ }else{
+ if (backoff){
+ outLambda = (float) (log10(lambda) - log10(bo));
+ }
+ else{
+ MY_ASSERT((lambda<UPPER_SINGLE_PRECISION_OF_1 && lambda>LOWER_SINGLE_PRECISION_OF_1) || bo<UPPER_SINGLE_PRECISION_OF_1 );
+ if (lambda<LOWER_SINGLE_PRECISION_OF_1){
+ outLambda = (float) log10(lambda);
+ }
+ else { //force to be 0.0 and hence to not output lambda
+ outLambda = DONT_PRINT;
+ }
+ }
+ }
+ }else { //i==maxlev
+ outLambda = DONT_PRINT;
+ }
+
+ VERBOSE(3,"mdiadaptlm::saveARPA_per_level(char *filename,int backoff,char* subdictfile ) writing w:|" << (char *)dict->decode(w) << "| pr:" << pr << " outLambda:" << outLambda << std::endl);
+ //cerr << ng << " freq " << dict->freq(w) << " - Pr " << pr << "\n";
+ out << (float) (pr?log10(pr):-99);
+ out << "\t" << (char *)dict->decode(w);
+ if (outLambda != DONT_PRINT){
+ out << "\t" << outLambda;
+ }
+ out << "\n";
+
+ num[i]++;
+ }
+
+
+ //add unigram with OOV and its accumulated oov probability
+ if (_OOV_unigram){
+ pr=oovprob;
+ num[i]++;
+ out << (float) (pr?log10(pr):-99);
+ out << "\t" << "<unk>\n";
+ }
+ }
+ else { //i>1 , bigrams, trigrams, fourgrams...
+ *ng.wordp(1)=0;
+ get(ng,1,1); //this
+ scan(ng,INIT,i);
+ while(scan(ng,CONT,i)) {
+
+ sng.trans(ng);
+
+ // frequency pruning: skip n-grams with low frequency
+ if (prune_ngram(i,sng.freq)) continue;
+
+ // skip n-grams containing OOV
+ if (sng.containsWord(subdict->OOV(),i)) continue;
+
+ // skip also n-grams containing eos symbols not in final position
+ if (sng.containsWord(dict->EoS(),i-1)) continue;
+
+ VERBOSE(3,"mdiadaptlm::saveARPA_per_level(char *filename,int backoff,char* subdictfile ) computing prob for ng:|" << ng << "| size:" << i << std::endl);
+ pr=mdiadaptlm::prob(ng,i,fstar,dummy,dummy2);
+ VERBOSE(3,"mdiadaptlm::saveARPA_per_level(char *filename,int backoff,char* subdictfile ) getting prob ng:|" << ng << "| size:" << i << " fstar:" << fstar << " pr:" << pr << std::endl);
+
+ //PATCH by Nicola (16-04-2008)
+ if (!(pr<=1.0 && pr > 1e-10)) {
+ cerr << ng << " " << pr << "\n";
+ MY_ASSERT(pr<=1.0);
+ cerr << "prob modified to 1e-10\n";
+ pr=1e-10;
+ }
+
+ if (i<maxlev) {
+ ng2=ng;
+ ng2.pushc(0); //extend by one
+
+ VERBOSE(3,"mdiadaptlm::saveARPA_per_level(char *filename,int backoff,char* subdictfile ) computing backoff for ng2:|" << ng2 << "| size:" << i+1 << std::endl);
+ mdiadaptlm::bodiscount(ng2,i+1,dummy,lambda,bo);
+ VERBOSE(3,"mdiadaptlm::saveARPA_per_level(char *filename,int backoff,char* subdictfile ) getting backoff for ng2:|" << ng2 << "| lambda:" << lambda << " bo:" << bo << std::endl);
+
+ if (fstar<UPPER_SINGLE_PRECISION_OF_0 && lambda>LOWER_SINGLE_PRECISION_OF_1){ //ngram must be skipped
+ outLambda = DONT_PRINT;
+ }else{
+ if (backoff){
+ outLambda = (float) (log10(lambda) - log10(bo));
+ }else{
+ MY_ASSERT((lambda<UPPER_SINGLE_PRECISION_OF_1 && lambda>LOWER_SINGLE_PRECISION_OF_1) || bo<UPPER_SINGLE_PRECISION_OF_1 );
+ if (lambda<LOWER_SINGLE_PRECISION_OF_1){
+ outLambda = (float) log10(lambda);
+ }else{ //no output of lambda if log10(lambda)==0
+ outLambda = DONT_PRINT;
+ }
+ }
+ }
+ } else { //i==maxlev
+ outLambda = DONT_PRINT;
+ }
+
+ VERBOSE(3,"mdiadaptlm::saveARPA_per_level(char *filename,int backoff,char* subdictfile ) writing ng:|" << ng << "| pr:" << pr << " outLambda:" << outLambda << std::endl);
+ if (fstar>=UPPER_SINGLE_PRECISION_OF_0 || outLambda!=DONT_PRINT ) {
+ out << (float) log10(pr);
+ out << "\t" << (char *)dict->decode(*ng.wordp(i));
+ for (int j=i-1; j>0; j--)
+ out << " " << (char *)dict->decode(*ng.wordp(j));
+ if (outLambda != DONT_PRINT){
+ out << "\t" << outLambda;
+ }
+ out << "\n";
+ num[i]++;
+ }
+ }
+ }
+
+ cerr << i << "grams tot:" << num[i] << "\n";
+ }
+
+ streampos last=out.tellp();
+
+ //update headers: seek back and patch the real counts, then restore position
+ for (int i=1; i<=maxlev; i++) {
+ sprintf(buff,"ngram %2d=%10u\n",i,num[i]);
+ out.seekp(pos[i]);
+ out << buff;
+ }
+
+ out.seekp(last);
+ out << "\\end\\" << "\n";
+ system("date");
+
+ VERBOSE(2,"mdiadaptlm::saveARPA_per_level END\n");
+ return 1;
+ };
+
+}//namespace irstlm
+
+/*
+ main(int argc,char** argv){
+ char* dictname=argv[1];
+ char* backngram=argv[2];
+ int depth=atoi(argv[3]);
+ char* forengram=argv[4];
+ char* testngram=argv[5];
+
+ dictionary dict(dictname);
+ ngramtable test(&dict,testngram,depth);
+
+ shiftbeta lm2(&dict,backngram,depth);
+ lm2.train();
+ //lm2.test(test,depth);
+
+ mdi lm(&dict,backngram,depth);
+ lm.train();
+ for (double w=0.0;w<=1.0;w+=0.1){
+ lm.getforelm(forengram);
+ lm.adapt(w);
+ lm.test(test,depth);
+ }
+ }
+ */
+
diff --git a/src/mdiadapt.h b/src/mdiadapt.h
new file mode 100644
index 0000000..fb95a5a
--- /dev/null
+++ b/src/mdiadapt.h
@@ -0,0 +1,158 @@
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+
+// Adapted LM classes: extension of interp classes
+
+#ifndef MF_MDIADAPTLM_H
+#define MF_MDIADAPTLM_H
+
+#include "ngramcache.h"
+#include "normcache.h"
+#include "interplm.h"
+
+#define DONT_PRINT 1000000
+
+namespace irstlm {
+// MDI-adapted language model: extends interplm with adaptation to a
+// foreground LM (forelm), per-level/per-word caching of probabilities and
+// backoff weights, and ARPA/binary serialization.
+class mdiadaptlm:public interplm
+{
+
+ int adaptlev; // adaptation level (set by adapt())
+ interplm* forelm; // foreground LM used for adaptation
+ double zeta0;
+ double oovscaling;
+ bool m_save_per_level; // selects per-level vs per-word save strategy
+
+ static bool mdiadaptlm_cache_enable;
+
+protected:
+ normcache *cache; // normalization-term (zeta) cache
+
+//to improve access speed
+ NGRAMCACHE_t** probcache; // per-level probability caches
+ NGRAMCACHE_t** backoffcache; // per-level backoff caches
+ int max_caching_level;
+
+ // serialization back-ends, dispatched by saveARPA()/saveBIN() below
+ int saveARPA_per_word(char *filename,int backoff=0,char* subdictfile=NULL);
+ int saveARPA_per_level(char *filename,int backoff=0,char* subdictfile=NULL);
+ int saveBIN_per_word(char *filename,int backoff=0,char* subdictfile=NULL,int mmap=0);
+ int saveBIN_per_level(char *filename,int backoff=0,char* subdictfile=NULL,int mmap=0);
+public:
+
+ mdiadaptlm(char* ngtfile,int depth=0,TABLETYPE tt=FULL);
+
+ inline normcache* get_zetacache() {
+ return cache;
+ }
+ inline NGRAMCACHE_t* get_probcache(int level);
+ inline NGRAMCACHE_t* get_backoffcache(int level);
+
+ // cache lifecycle management
+ void create_caches(int mcl);
+ void init_caches();
+ void init_caches(int level);
+ void delete_caches();
+ void delete_caches(int level);
+
+ void check_cache_levels();
+ void check_cache_levels(int level);
+ void reset_caches();
+ void reset_caches(int level);
+
+ void caches_stat();
+
+ double gis_step; // GIS step size used by adapt()
+
+ double zeta(ngram ng,int size);
+
+ // discounted frequency (fstar) and interpolation weight (lambda)
+ int discount(ngram ng,int size,double& fstar,double& lambda,int cv=0);
+
+ // as discount(), additionally returning the backoff weight bo
+ int bodiscount(ngram ng,int size,double& fstar,double& lambda,double& bo);
+
+ virtual int compute_backoff();
+ virtual int compute_backoff_per_level();
+ virtual int compute_backoff_per_word();
+
+ double backunig(ngram ng);
+
+ double foreunig(ngram ng);
+
+ int adapt(char* ngtfile,int alev=1,double gis_step=0.4);
+
+ int scalefact(char* ngtfile);
+
+ int savescalefactor(char* filename);
+
+ double scalefact(ngram ng);
+
+ double prob(ngram ng,int size);
+ double prob(ngram ng,int size,double& fstar,double& lambda, double& bo);
+
+ double prob2(ngram ng,int size,double & fstar);
+
+ double txclprob(ngram ng,int size);
+
+ int saveASR(char *filename,int backoff,char* subdictfile=NULL);
+ int saveMT(char *filename,int backoff,char* subdictfile=NULL,int resolution=10000000,double decay=0.999900);
+
+ // dispatch to per-level or per-word ARPA serialization (see m_save_per_level)
+ int saveARPA(char *filename,int backoff=0,char* subdictfile=NULL){
+ if (m_save_per_level){
+ cerr << " per level ...";
+ return saveARPA_per_level(filename, backoff, subdictfile);
+ }else{
+ cerr << " per word ...";
+ return saveARPA_per_word(filename, backoff, subdictfile);
+ }
+ }
+ // dispatch to per-level or per-word binary serialization
+ int saveBIN(char *filename,int backoff=0,char* subdictfile=NULL,int mmap=0){
+ if (m_save_per_level){
+ cerr << " per level ...";
+ return saveBIN_per_level(filename, backoff, subdictfile, mmap);
+ }else{
+ cerr << " per word ...";
+ return saveBIN_per_word(filename, backoff, subdictfile, mmap);
+ }
+ }
+
+ inline void save_per_level(bool value){ m_save_per_level=value; }
+ inline bool save_per_level() const { return m_save_per_level; }
+
+ int netsize();
+
+ ~mdiadaptlm();
+
+ // round-half-up to the nearest integer (returned as double)
+ double myround(double x) {
+ long int value = (long int) x;
+ return (x-value)>0.500?value+1.0:(double)value;
+ }
+
+ inline static bool is_train_cache_enabled() {
+ VERBOSE(3,"inline static bool is_train_cache_enabled() " << mdiadaptlm_cache_enable << std::endl);
+ return mdiadaptlm_cache_enable;
+ }
+
+};
+
+}//namespace irstlm
+#endif
+
+
+
+
+
+
diff --git a/src/mempool.cpp b/src/mempool.cpp
new file mode 100644
index 0000000..36220f8
--- /dev/null
+++ b/src/mempool.cpp
@@ -0,0 +1,505 @@
+// $Id: mempool.cpp 302 2009-08-25 13:04:13Z nicolabertoldi $
+
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+
+// An efficient memory pool manager
+// by M. Federico
+// Copyright Marcello Federico, ITC-irst, 1998
+
+#include <stdio.h>
+#include <cstring>
+#include <string.h>
+#include <cstdlib>
+#include <stdlib.h>
+#include <iostream>
+#include <ostream>
+#include "util.h"
+#include "mempool.h"
+
+using namespace std;
+
+/*! The pool contains:
+ - entries of size is
+ - tables for bs entries
+ */
+
+
+// Construct a pool of fixed-size items.
+// is: requested item size in bytes (rounded up below so each slot can hold
+// at least a char* — slots double as free-list links); bs: items per block.
+mempool::mempool(int is, int bs)
+{
+
+ // Item size must be a multiple of the alignment step, sizeof(char *).
+ // NOTE(review): when is is already a multiple, the formula below still adds a
+ // full word (e.g. 8 -> 16) — harmless over-allocation, but probably unintended.
+
+ is=(is>(int)sizeof(char *)?is:0);
+
+ is=is + sizeof(char *) - (is % sizeof(char *));
+
+ item_size = is;
+
+ block_size = bs;
+
+ true_size = is * bs;
+
+ block_list = new memnode;
+
+ block_list->block = new char[true_size];
+
+ // NOTE(review): fills with the character '0' (0x30), not the byte 0 —
+ // presumably a debug filler; confirm zero bytes were not intended.
+ memset(block_list->block,'0',true_size);
+
+ block_list->next = 0;
+
+ blocknum = 1;
+
+ entries = 0;
+
+ // build free list: each free slot stores a pointer to the next free slot
+
+ char *ptr = free_list = block_list->block;
+
+ for (int i=0; i<block_size-1; i++) {
+ *(char **)ptr= ptr + item_size;
+ ptr+=item_size;
+ }
+ *(char **)ptr = NULL; //last item
+
+}
+
+
+// Pop one item from the free list, growing the pool by a new block (prepended
+// to block_list) when the free list is exhausted. Returns a pointer to an
+// item_size-byte slot; its first word is zeroed before return.
+char * mempool::allocate()
+{
+
+ char *ptr;
+
+ if (free_list==NULL) {
+ memnode *new_block = new memnode;
+
+ new_block->block = new char[true_size];
+
+ //memset(new_block->block,'0',true_size);
+
+ new_block->next = block_list;
+
+ block_list=new_block; // update block list
+
+ /* rebuild the free list over the fresh block */
+
+ ptr = free_list = block_list->block;
+
+ for (int i=0; i<block_size-1; i++) {
+ *(char **)ptr = ptr + item_size;
+ ptr = ptr + item_size;
+ }
+
+ *(char **)ptr=NULL;
+
+ blocknum++;
+ }
+
+ MY_ASSERT(free_list);
+
+ ptr = free_list;
+
+ free_list=*(char **)ptr;
+
+ *(char **)ptr=NULL; // reset the link word of the handed-out item
+
+ entries++;
+
+ return ptr;
+
+}
+
+
+// Return an item to the pool by pushing it onto the free list. Always
+// returns 1: the ownership check (does addr belong to this pool, is it
+// slot-aligned?) is deliberately disabled below for speed.
+int mempool::free(char* addr)
+{
+
+ // do not check if it belongs to this pool !!
+ /*
+ memnode *list=block_list;
+ while ((list != NULL) &&
+ ((addr < list->block) ||
+ (addr >= (list->block + true_size))))
+ list=list->next;
+
+ if ((list==NULL) || (((addr - list->block) % item_size)!=0))
+ {
+ //cerr << "mempool::free-> addr does not belong to this pool\n";
+ return 0;
+ }
+ */
+
+ *(char **)addr=free_list;
+ free_list=addr;
+
+ entries--;
+
+ return 1;
+}
+
+
+// Release every block and its list node; outstanding items become invalid.
+mempool::~mempool()
+{
+ memnode *ptr;
+
+ while (block_list !=NULL) {
+ ptr=block_list->next;
+ delete [] block_list->block;
+ delete block_list;
+ block_list=ptr;
+ }
+
+}
+
+// Print an ASCII occupancy map to co: one line per block, '#' for a slot in
+// use and '-' for a slot currently on the free list.
+void mempool::map (ostream& co)
+{
+
+ co << "mempool memory map:\n";
+ // walk the free list once per block
+
+ memnode *bl=block_list;
+ char *fl=free_list;
+
+ char* img=new char[block_size+1];
+ img[block_size]='\0';
+
+ while (bl !=NULL) {
+
+ memset(img,'#',block_size);
+
+ fl=free_list;
+ while (fl != NULL) {
+ if ((fl >= bl->block)
+ &&
+ (fl < bl->block + true_size)) {
+ img[(fl-bl->block)/item_size]='-';
+ }
+
+ fl=*(char **)fl;
+ }
+
+ co << img << "\n";
+ bl=bl->next;
+ }
+ delete [] img;
+}
+
+// Log usage statistics: live entries, allocated blocks, total memory in KB.
+void mempool::stat()
+{
+
+ VERBOSE(1, "mempool class statistics\n"
+ << "entries " << entries
+ << " blocks " << blocknum
+ << " used memory " << (blocknum * true_size)/1024 << " Kb\n");
+}
+
+
+
+// Construct a string stack backed by a linked list of zero-filled blocks of
+// bs bytes each; strings are appended into the current block by push().
+strstack::strstack(int bs)
+{
+
+ size=bs;
+ list=new memnode;
+
+ list->block=new char[size];
+
+ list->next=0;
+
+ memset(list->block,'\0',size);
+ idx=0; // write offset inside the current (head) block
+
+ waste=0; // bytes abandoned at the end of full blocks
+ memory=size;
+ entries=0;
+ blocknum=1;
+
+}
+
+
+// Log usage statistics: stored strings, allocated blocks, total memory in KB.
+void strstack::stat()
+{
+
+ VERBOSE(1, "strstack class statistics\n"
+ << "entries " << entries
+ << " blocks " << blocknum
+ << " used memory " << memory/1024 << " Kb\n");
+}
+
+
+// Copy s (with its terminating '\0') into the current block and return a
+// pointer to the stored copy, valid until that entry is popped. A string
+// that cannot fit in an empty block is a fatal error; if it does not fit in
+// the remaining space of the current block, a new head block is prepended.
+const char *strstack::push(const char *s)
+{
+ int len=strlen(s);
+
+ if ((len+1) >= size) {
+ exit_error(IRSTLM_ERROR_DATA, "strstack::push string is too long");
+ };
+
+ if ((idx+len+1) >= size) {
+ //append a new block
+ //there must be space to
+ //put the index after
+ //the word
+
+ waste+=size-idx;
+ blocknum++;
+ memory+=size;
+
+ memnode* nd=new memnode;
+ nd->block=new char[size];
+ nd->next=list;
+
+ list=nd;
+
+ memset(list->block,'\0',size);
+
+ idx=0;
+
+ }
+
+ // append in current block
+
+ strcpy(&list->block[idx],s);
+
+ idx+=len+1;
+
+ entries++;
+
+ return &list->block[idx-len-1];
+
+}
+
+
+// Pop the most recently pushed string and return a pointer to it, or 0
+// when the stack is empty. The backward scans rely on unused space in
+// each block being zero-filled. NOTE(review): the returned pointer
+// aliases internal storage that is cleared/reused by later operations,
+// so callers must copy it before the next push/pop.
+const char *strstack::pop()
+{
+
+ if (list==0) return 0;
+
+ if (idx==0) {
+
+ // free this block and go to next
+
+ memnode *ptr=list->next;
+
+ delete [] list->block;
+ delete list;
+
+ list=ptr;
+
+ if (list==0)
+ return 0;
+ else
+ idx=size-1;
+ }
+
+ //go back to first non \0
+ while (idx>0)
+ if (list->block[idx--]!='\0')
+ break;
+
+ //go back to first \0
+ while (idx>0)
+ if (list->block[idx--]=='\0')
+ break;
+
+ entries--;
+
+ // idx now sits just before the popped string (or at 0 when the
+ // popped string starts the block); clear the popped bytes so the
+ // zero-fill invariant holds for future scans
+ if (list->block[idx+1]=='\0') {
+ idx+=2;
+ memset(&list->block[idx],'\0',size-idx);
+ return &list->block[idx];
+ } else {
+ idx=0;
+ memset(&list->block[idx],'\0',size);
+ return &list->block[0];
+ }
+}
+
+
+// Return a pointer to the most recently pushed string without removing
+// it, or 0 when the stack is empty. Same backward-scan logic as pop(),
+// but performed on local copies of idx/list so no state is modified.
+const char *strstack::top()
+{
+
+ int tidx=idx;
+ memnode *tlist=list;
+
+ if (tlist==0) return 0;
+
+ if (idx==0) {
+
+ // current block is empty: peek into the previous one
+ tlist=tlist->next;
+
+ if (tlist==0) return 0;
+
+ tidx=size-1;
+ }
+
+ //go back to first non \0
+ while (tidx>0)
+ if (tlist->block[tidx--]!='\0')
+ break;
+
+ //aaa\0bbb\0\0\0\0
+
+ //go back to first \0
+ while (tidx>0)
+ if (tlist->block[tidx--]=='\0')
+ break;
+
+ if (tlist->block[tidx+1]=='\0') {
+ tidx+=2;
+ return &tlist->block[tidx];
+ } else {
+ // the top string starts at the beginning of the block
+ tidx=0;
+ return &tlist->block[0];
+ }
+
+}
+
+
+// Destructor: release every remaining block and its list node.
+strstack::~strstack()
+{
+ memnode *ptr;
+ while (list !=NULL) {
+ ptr=list->next;
+ delete [] list->block;
+ delete list;
+ list=ptr;
+ }
+}
+
+
+// Create a storage manager with one mempool slot for every item size
+// in [0..maxsize]; the pools themselves are created lazily by
+// allocate() on first request of each size.
+storage::storage(int maxsize,int blocksize)
+{
+ newmemory=0;
+ newcalls=0;
+ setsize=maxsize;
+ poolsize=blocksize; //in bytes
+ poolset=new mempool* [setsize+1];
+ for (int i=0; i<=setsize; i++)
+ poolset[i]=NULL;
+}
+
+
+// Destructor: delete every lazily-created pool, then the pool array.
+storage::~storage()
+{
+ for (int i=0; i<=setsize; i++)
+ if (poolset[i])
+ delete poolset[i];
+ delete [] poolset;
+}
+
+// Allocate size bytes. Small requests (size<=setsize) are served by the
+// per-size mempool, created on demand; larger requests go to calloc,
+// with 8 bytes of assumed allocator overhead added to the accounting.
+// NOTE(review): the calloc path returns zeroed memory, while the
+// mempool path presumably does not — confirm callers do not rely on
+// zero-initialization for small sizes.
+char *storage::allocate(int size)
+{
+
+ if (size<=setsize) {
+ if (!poolset[size]) {
+ poolset[size]=new mempool(size,poolsize/size);
+ }
+ return poolset[size]->allocate();
+ } else {
+
+ newmemory+=size+8;
+ newcalls++;
+ char* p=(char *)calloc(sizeof(char),size);
+ if (p==NULL) {
+ exit_error(IRSTLM_ERROR_MEMORY, "storage::alloc insufficient memory");
+ }
+ return p;
+ }
+}
+
+// Grow an allocation from oldsize to newsize bytes (growth only, see
+// assert). Pool-to-pool and pool-to-heap moves copy the old contents
+// and return the old entry to its pool; heap-to-heap uses realloc and
+// prints a one-character progress marker ('r' in place / 'a' moved).
+char *storage::reallocate(char *oldptr,int oldsize,int newsize)
+{
+
+ char *newptr;
+
+ MY_ASSERT(newsize>oldsize);
+
+ if (oldsize<=setsize) {
+ if (newsize<=setsize) {
+ if (!poolset[newsize])
+ poolset[newsize]=new mempool(newsize,poolsize/newsize);
+ newptr=poolset[newsize]->allocate();
+ // pool entries are not zeroed by the pool itself
+ memset((char*)newptr,0,newsize);
+ } else
+ newptr=(char *)calloc(sizeof(char),newsize);
+
+ if (oldptr && oldsize) {
+ memcpy(newptr,oldptr,oldsize);
+ poolset[oldsize]->free(oldptr);
+ }
+ } else {
+ newptr=(char *)realloc(oldptr,newsize);
+ if (newptr==oldptr)
+ cerr << "r\b";
+ else
+ cerr << "a\b";
+ }
+ if (newptr==NULL) {
+ exit_error(IRSTLM_ERROR_MEMORY,"storage::realloc insufficient memory");
+ }
+
+ return newptr;
+}
+
+// Release an entry previously obtained from allocate()/reallocate().
+// Entries of size<=setsize go back to their per-size pool; larger
+// entries were obtained from calloc() and must be released with the
+// C allocator. BUGFIX: the unqualified call free(addr) resolved to
+// storage::free(addr,0) (the member hides the global function), which
+// fell through to the no-op poolset[0] branch and leaked every
+// oversized allocation; qualify with :: to reach the global free().
+// Always returns 1.
+int storage::free(char *addr,int size)
+{
+
+ /*
+ while(size<=setsize){
+ if (poolset[size] && poolset[size]->free(addr))
+ break;
+ size++;
+ }
+ */
+
+ if (size>setsize)
+ return ::free(addr),1;
+ else {
+ // poolset[size] is NULL when no entry of this size was ever
+ // allocated; in that case the call is intentionally a no-op
+ if (poolset[size]) poolset[size]->free(addr);
+ }
+ return 1;
+}
+
+
+// Report aggregate statistics (heap calls/memory, active pools, pool
+// memory and waste) at VERBOSE level 1.
+void storage::stat()
+{
+ IFVERBOSE(1){
+ int used=0;
+ int memory=sizeof(char *) * setsize;
+ int waste=0;
+
+ for (int i=0; i<=setsize; i++)
+ if (poolset[i]) {
+ used++;
+ memory+=poolset[i]->used();
+ waste+=poolset[i]->wasted();
+ }
+
+ VERBOSE(1, "storage class statistics\n"
+ << "alloc entries " << newcalls
+ << " used memory " << newmemory/1024 << "Kb\n"
+ << "mpools " << setsize
+ << " active " << used
+ << " used memory " << memory/1024 << "Kb"
+ << " wasted " << waste/1024 << "Kb\n");
+ }
+}
+
+
diff --git a/src/mempool.h b/src/mempool.h
new file mode 100644
index 0000000..ba44d21
--- /dev/null
+++ b/src/mempool.h
@@ -0,0 +1,194 @@
+// $Id: mempool.h 383 2010-04-23 15:29:28Z nicolabertoldi $
+
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+
+// An efficient memory manager
+// by M. Federico
+// Copyright Marcello Federico, ITC-irst, 1998
+
+#ifndef MF_MEMPOOL_H
+#define MF_MEMPOOL_H
+
+#ifndef NULL
+const int NULL=0;
+#endif
+
+#include <iostream> // std::ostream
+
+//! Memory block
+/*! This can be used by:
+- mempool to store items of fixed size
+- strstack to store strings of variable size
+*/
+
+
+#define MP_BLOCK_SIZE 1000000
+
+// Node of a singly linked list of raw memory blocks; the block array
+// is allocated and freed by the owning mempool/strstack, not here.
+class memnode
+{
+ friend class mempool; //!< grant access
+ friend class strstack; //!< grant access
+ char *block; //!< block of memory
+ memnode *next; //!< next block ptr
+public:
+ //! Creates a memory node
+ memnode():block(NULL), next(NULL){};
+
+ //! Destroys memory node (does not free block; owner does)
+ ~memnode(){};
+};
+
+
+//! Memory pool
+
+/*! A memory pool is composed of:
+ - a linked list of block_num memory blocks
+ - each block might contain up to block_size items
+ - each item is made of exactly item_size bytes
+*/
+
+// Fixed-size-item allocator: items are carved out of large blocks and
+// recycled through an intrusive free list (see mempool.cpp).
+class mempool
+{
+ int block_size; //!< number of entries per block
+ int item_size; //!< number of bytes per entry
+ int true_size; //!< number of bytes per block
+ memnode* block_list; //!< list of blocks
+ char* free_list; //!< free entry list
+ int entries; //!< number of stored entries
+ int blocknum; //!< number of allocated blocks
+public:
+
+ //! Creates a memory pool of items of is bytes, bs items per block
+ mempool(int is, int bs=MP_BLOCK_SIZE);
+
+ //! Destroys memory pool
+ ~mempool();
+
+ //! Prints a map of memory occupancy
+ void map(std::ostream& co);
+
+ //! Allocates a single memory entry
+ char *allocate();
+
+ //! Frees a single memory entry
+ int free(char* addr);
+
+ //! Prints statistics about this mempool
+ void stat();
+
+ //! Returns effectively used memory (bytes)
+ /*! includes 8 bytes required by each call of new */
+
+ int used() const {
+ return blocknum * (true_size + 8);
+ }
+
+ //! Returns amount of wasted memory (bytes)
+ int wasted() const {
+ return used()-(entries * item_size);
+ }
+};
+
+//! A stack to store strings
+
+/*!
+ The stack is composed of
+ - a list of blocks memnode of fixed size
+ - attribute blocknum tells the block on top
+ - attribute idx tells position of the top string
+*/
+
+// LIFO store for variable-length strings packed into fixed-size blocks
+// (see mempool.cpp for push/pop/top semantics).
+class strstack
+{
+ memnode* list; //!< list of memory blocks
+ int size; //!< size of each block
+ int idx; //!< index of last stored string
+ int waste; //!< current waste of memory
+ int memory; //!< current use of memory
+ int entries; //!< current number of stored strings
+ int blocknum; //!< current number of used blocks
+
+public:
+
+ //! Creates a stack with blocks of bs bytes
+ strstack(int bs=1000);
+
+ ~strstack();
+
+ //! Pushes a copy of s; returns pointer to the stored copy
+ const char *push(const char *s);
+
+ //! Removes and returns the most recent string (0 when empty)
+ const char *pop();
+
+ //! Returns the most recent string without removing it (0 when empty)
+ const char *top();
+
+ void stat();
+
+ int used() const {
+ return memory;
+ }
+
+ int wasted() const {
+ return waste;
+ }
+
+};
+
+
+//! Manages multiple memory pools
+
+/*!
+ This class permits to manage memory pools
+ with items up to a specified size.
+ - items within the allowed range are stored in memory pools
+ - items larger than the limit are allocated with new
+*/
+
+
+// Front-end over an array of mempools indexed by item size: requests
+// up to setsize bytes are pooled, larger ones fall back to the C heap.
+class storage
+{
+ mempool **poolset; //!< array of memory pools
+ int setsize; //!< number of memory pools/maximum elem size
+ int poolsize; //!< size of each block
+ int newmemory; //!< stores amount of used memory
+ int newcalls; //!< stores number of allocated blocks
+public:
+
+ //! Creates storage
+ storage(int maxsize,int blocksize);
+
+ //! Destroys storage
+ ~storage();
+
+ /* names of below functions have been changed so as not to interfere with macros for malloc/realloc/etc -- EVH */
+
+ //! Allocates memory
+ char *allocate(int size);
+
+ //! Realloc memory (growth only; see assert in implementation)
+ char *reallocate(char *oldptr,int oldsize,int newsize);
+
+ //! Frees memory of an entry; size selects the owning pool
+ int free(char *addr,int size=0);
+
+ //! Prints statistics about storage
+ void stat();
+};
+
+#endif
diff --git a/src/mfstream.cpp b/src/mfstream.cpp
new file mode 100644
index 0000000..4ae076c
--- /dev/null
+++ b/src/mfstream.cpp
@@ -0,0 +1,219 @@
+// $Id: mfstream.cpp 294 2009-08-19 09:57:27Z mfederico $
+
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit, compile LM
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+
+#include <iostream>
+#include <fstream>
+#include <streambuf>
+#include <cstdio>
+#include "util.h"
+#include "mfstream.h"
+#include "gzfilebuf.h"
+
+using namespace std;
+
+// Open either a regular file or a command pipe. A name containing a
+// space is treated as a shell command: it is run through popen() and
+// its file descriptor is wrapped in an fdbuf that replaces this
+// stream's buffer. Otherwise the call defers to fstream::open.
+// For commands only one of in/out/app is honoured.
+void mfstream::open(const char *name,openmode mode)
+{
+
+ char cmode[10];
+
+ if (strchr(name,' ')!=0) {
+ if (mode & ios::in)
+ strcpy(cmode,"r");
+ else if (mode & ios::out)
+ strcpy(cmode,"w");
+ else if (mode & ios::app)
+ strcpy(cmode,"a");
+ else {
+ exit_error(IRSTLM_ERROR_IO, "cannot open file");
+ }
+ _cmd=1;
+ // NOTE(review): _cmdname is a fixed 500-byte buffer; command
+ // names longer than that overflow here — confirm callers bound it
+ strcpy(_cmdname,name);
+ _FILE=popen(name,cmode);
+ buf=new fdbuf(fileno(_FILE));
+ iostream::rdbuf((streambuf*) buf);
+ } else {
+ _cmd=0;
+ fstream::open(name,mode);
+ }
+
+}
+
+
+// Close the stream: pclose() the pipe for commands, fstream::close
+// otherwise. _cmd=2 marks the stream closed so the destructor does
+// not close a second time.
+void mfstream::close()
+{
+ if (_cmd==1) {
+ pclose(_FILE);
+ delete buf;
+ } else {
+ fstream::clear();
+ fstream::close();
+ }
+ _cmd=2;
+}
+
+
+
+// Reverse, in place, the byte order of n items of sz bytes each
+// (endianness conversion helper for readx/writex). No-op when n<1 or
+// sz<2. Always returns 0.
+int mfstream::swapbytes(char *p, int sz, int n)
+{
+ char c,
+ *l,
+ *h;
+
+ if((n<1) ||(sz<2)) return 0;
+ // swap the outermost byte pair and move inward
+ for(; n--; p+=sz) for(h=(l=p)+sz; --h>l; l++) {
+ c=*h;
+ *h=*l;
+ *l=c;
+ }
+ return 0;
+
+};
+
+
+// Byte-order-safe write at position loc; the current position is saved
+// first and restored afterwards. Not allowed on command streams
+// (tellp/seekp abort in that case).
+mfstream& mfstream::iwritex(streampos loc,void *ptr,int size,int n)
+{
+ streampos pos=tellp();
+
+ seekp(loc);
+
+ writex(ptr,size,n);
+
+ seekp(pos);
+
+ return *this;
+
+}
+
+
+// Binary read of n items of sz bytes with byte-order normalization:
+// the test *(short*)"AB"==0x4241 is true on little-endian hosts, on
+// which each item's bytes are swapped after reading (on-disk data is
+// big-endian).
+mfstream& mfstream::readx(void *p, int sz,int n)
+{
+ if(!read((char *)p, sz * n)) return *this;
+
+ if(*(short *)"AB"==0x4241) {
+ swapbytes((char*)p, sz,n);
+ }
+
+ return *this;
+}
+
+// Binary write of n items of sz bytes with byte-order normalization:
+// on little-endian hosts the buffer is swapped to big-endian before
+// writing and swapped back afterwards to restore the caller's data.
+mfstream& mfstream::writex(void *p, int sz,int n)
+{
+ if(*(short *)"AB"==0x4241) {
+ swapbytes((char*)p, sz,n);
+ }
+
+ write((char *)p, sz * n);
+
+ if(*(short *)"AB"==0x4241) swapbytes((char*)p, sz,n);
+
+ return *this;
+}
+
+//! Tells current position within a file
+streampos mfstream::tellp() {
+ if (_cmd!=0)
+ exit_error(IRSTLM_ERROR_IO, "mfstream::tellp tellp not allowed on commands");
+
+ return (streampos) fstream::tellg();
+}
+
+//! Seeks a position within a file
+mfstream& mfstream::seekp(streampos loc) {
+ if (_cmd==0)
+ fstream::seekg(loc);
+ else {
+ exit_error(IRSTLM_ERROR_IO, "mfstream::seekp seekp not allowed on commands");
+ }
+ return *this;
+}
+
+//! Reopens an input stream
+mfstream& mfstream::reopen() {
+
+ if (_mode != in) {
+ exit_error(IRSTLM_ERROR_IO, "mfstream::reopen() openmode must be ios:in");
+ }
+
+ if (strlen(_cmdname)>0) {
+ char *a=new char[strlen(_cmdname)+1];
+ strcpy(a,_cmdname);
+ cerr << "close/open " << a <<"\n";
+ close();
+ open(a,ios::in);
+ delete []a;
+ } else{
+ seekp(0);
+ }
+ return *this;
+}
+
+
+
+// Open filePath for reading; files whose name ends in ".gz" are read
+// through a decompressing gzfilebuf instead of a plain filebuf.
+// NOTE(review): _good records only the plain filebuf probe; for .gz
+// files it is not updated after switching to gzfilebuf — confirm
+// gzfilebuf reports its own failures.
+inputfilestream::inputfilestream(const std::string &filePath)
+: std::istream(0),
+m_streambuf(0)
+{
+ //check if file is readable
+ std::filebuf* fb = new std::filebuf();
+ _good=(fb->open(filePath.c_str(), std::ios::in)!=NULL);
+
+ if (filePath.size() > 3 &&
+ filePath.substr(filePath.size() - 3, 3) == ".gz") {
+ fb->close();
+ delete fb;
+ m_streambuf = new gzfilebuf(filePath.c_str());
+ } else {
+ m_streambuf = fb;
+ }
+ this->init(m_streambuf);
+}
+
+// Destructor: release whichever stream buffer the constructor chose.
+inputfilestream::~inputfilestream()
+{
+ delete m_streambuf;
+ m_streambuf = 0;
+}
+
+// Intentionally a no-op: the buffer is released by the destructor.
+void inputfilestream::close()
+{
+}
+
+
+
+/*
+ int main()
+ {
+
+ char word[1000];
+
+ mfstream inp("cat pp",ios::in);
+ mfbstream outp("aa",ios::out,100);
+
+ while (inp >> word){
+ outp << word << "\n";
+ cout << word << "\n";
+ }
+
+
+ }
+
+ */
diff --git a/src/mfstream.h b/src/mfstream.h
new file mode 100644
index 0000000..89b6f3c
--- /dev/null
+++ b/src/mfstream.h
@@ -0,0 +1,218 @@
+// $Id: mfstream.h 383 2010-04-23 15:29:28Z nicolabertoldi $
+
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit, compile LM
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+#include <stdio.h>
+#include <stdlib.h>
+#include <unistd.h>
+#include <string.h>
+#include <iostream>
+#include <cstring>
+#include <cstdlib>
+#include <fstream>
+#include <streambuf>
+#include <cstdio>
+
+using namespace std;
+
+#ifndef MF_STREAM_H
+#define MF_STREAM_H
+
+extern "C" {
+ ssize_t write (int fd, const void* buf, size_t num);
+ ssize_t read (int fd, void* buf, size_t num);
+ FILE *popen(const char *command, const char *type);
+ int pclose(FILE *stream);
+ int fseek( FILE *stream, long offset, int whence);
+ long ftell( FILE *stream);
+};
+
+
+//! File description for I/O stream buffer
+//! Stream buffer over a raw POSIX file descriptor, used by mfstream to
+//! wrap the pipe obtained from popen(). Maintains a 4-byte putback
+//! area at the front of a 10-byte buffer (so only 6 fresh bytes are
+//! fetched per underflow); bulk reads/writes bypass the buffer via
+//! xsgetn/xsputn.
+class fdbuf : public std::streambuf
+{
+
+protected:
+ int fd; // file descriptor
+
+ // write one character
+ virtual int_type overflow (int_type c) {
+ char z = c;
+ if (c != EOF) {
+ if (write (fd, &z, 1) != 1) {
+ return EOF;
+ }
+ }
+ //cerr << "overflow: \n";
+ //cerr << "pptr: " << (int) pptr() << "\n";
+ return c;
+ }
+
+ // write multiple characters directly to the descriptor
+ virtual
+ std::streamsize xsputn (const char* s,
+ std::streamsize num) {
+ return write(fd,s,num);
+
+ }
+
+ // seeking is unsupported on a pipe; report and return position 0
+ virtual streampos seekpos ( streampos /* unused parameter: sp */, ios_base::openmode /* unused parameter: which */= ios_base::in | ios_base::out ) {
+ std::cerr << "mfstream::seekpos is not implemented" << std::endl;;
+
+ return (streampos) 0;
+ }
+
+ //read one character
+ virtual int_type underflow () {
+ // is read position before end of buffer?
+ if (gptr() < egptr()) {
+ return traits_type::to_int_type(*gptr());
+ }
+
+ /* process size of putback area
+ * - use number of characters read
+ * - but at most four
+ */
+ int numPutback;
+ numPutback = gptr() - eback();
+ if (numPutback > 4) {
+ numPutback = 4;
+ }
+
+ /* copy up to four characters previously read into
+ * the putback buffer (area of first four characters)
+ */
+ std::memmove (buffer+(4-numPutback), gptr()-numPutback,
+ numPutback);
+
+ // read new characters
+ int num;
+ num = read (fd, buffer+4, bufferSize-4);
+ if (num <= 0) {
+ // ERROR or EOF
+ return EOF;
+ }
+
+ // reset buffer pointers
+ setg (buffer+(4-numPutback), // beginning of putback area
+ buffer+4, // read position
+ buffer+4+num); // end of buffer
+
+ // return next character
+ return traits_type::to_int_type(*gptr());
+ }
+
+
+ // read multiple characters directly from the descriptor
+ virtual
+ std::streamsize xsgetn (char* s,
+ std::streamsize num) {
+ return read(fd,s,num);
+ }
+
+ static const int bufferSize = 10; // size of the data buffer
+ char buffer[bufferSize]; // data buffer
+
+public:
+
+ // constructor: start with an empty get area (forces first underflow)
+ fdbuf (int _fd) : fd(_fd) {
+ setg (buffer+4, // beginning of putback area
+ buffer+4, // read position
+ buffer+4); // end position
+ }
+
+};
+
+
+
+//! Extension of fstream to commands
+
+//! fstream extension that can also read from / write to shell commands
+//! (names containing a space are run through popen; see mfstream.cpp).
+//! _cmd: 0 = plain file, 1 = command pipe, 2 = closed.
+class mfstream : public std::fstream
+{
+
+protected:
+ fdbuf* buf; // pipe buffer (command streams only)
+ int _cmd; // stream kind, see class comment
+ openmode _mode; // mode recorded for reopen()
+ FILE* _FILE; // popen handle (command streams only)
+
+ char _cmdname[500]; // command/file name for reopen()
+
+ int swapbytes(char *p, int sz, int n);
+
+public:
+
+ //! Creates and opens a file/command stream without a specified nmode
+ //! NOTE(review): _mode is not initialized here and open() does not
+ //! set it either — reopen() on such a stream reads an indeterminate
+ //! value; confirm callers use the (name,mode) constructor.
+ mfstream () : std::fstream(), buf(NULL), _cmd(0), _FILE(NULL) {
+ _cmdname[0]='\0';
+ }
+
+ //! Creates and opens a file/command stream in a specified nmode
+ mfstream (const char* name,openmode mode) : std::fstream() {
+ _cmdname[0]='\0';
+ _mode=mode;
+ open(name,mode);
+ }
+
+ //! Closes and destroys a file/command stream (skipped when already closed)
+ ~mfstream() {
+ if (_cmd<2) close();
+ }
+
+ //! Opens an existing mfstream
+ void open(const char *name,openmode mode);
+
+ //! Closes an existing mfstream
+ void close();
+
+ //! Write function for machine-independent byte order
+ mfstream& writex(void *p, int sz,int n=1);
+
+ //! Read function for machine-independent byte order
+ mfstream& readx(void *p, int sz,int n=1);
+
+ //! Write function at a given stream position for machine-independent byte order
+ mfstream& iwritex(streampos loc,void *ptr,int size,int n=1);
+
+ //! Tells current position within a file
+ streampos tellp();
+
+ //! Seeks a position within a file
+ mfstream& seekp(streampos loc);
+
+ //! Reopens an input stream
+ mfstream& reopen();
+};
+
+//! Input stream that transparently decompresses ".gz" files
+//! (implementation in mfstream.cpp).
+class inputfilestream : public std::istream
+{
+protected:
+ std::streambuf *m_streambuf; // plain filebuf or gzfilebuf
+ bool _good; // result of the initial open probe
+public:
+
+ inputfilestream(const std::string &filePath);
+ ~inputfilestream();
+ inline bool good() { return _good; }
+ void close();
+};
+
+#endif
diff --git a/src/mixture.cpp b/src/mixture.cpp
new file mode 100644
index 0000000..de24433
--- /dev/null
+++ b/src/mixture.cpp
@@ -0,0 +1,581 @@
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+
+
+#include <cmath>
+#include <sstream>
+#include "mfstream.h"
+#include "mempool.h"
+#include "dictionary.h"
+#include "n_gram.h"
+#include "ngramtable.h"
+#include "interplm.h"
+#include "normcache.h"
+#include "ngramcache.h"
+#include "mdiadapt.h"
+#include "shiftlm.h"
+#include "linearlm.h"
+#include "mixture.h"
+#include "cmd.h"
+#include "util.h"
+
+using namespace std;
+
+namespace irstlm {
+//
+//Mixture interpolated language model
+//
+
+// Maps command-line sub-LM type names (long and short forms) to the
+// numeric smoothing-method codes dispatched on in mixture's constructor.
+static Enum_T SLmTypeEnum [] = {
+ { (char*)"ImprovedKneserNey", IMPROVED_KNESER_NEY },
+ { (char*)"ikn", IMPROVED_KNESER_NEY },
+ { (char*)"KneserNey", KNESER_NEY },
+ { (char*)"kn", KNESER_NEY },
+ { (char*)"ModifiedShiftBeta", MOD_SHIFT_BETA },
+ { (char*)"msb", MOD_SHIFT_BETA },
+ { (char*)"ImprovedShiftBeta", IMPROVED_SHIFT_BETA },
+ { (char*)"isb", IMPROVED_SHIFT_BETA },
+ { (char*)"InterpShiftBeta", SHIFT_BETA },
+ { (char*)"ShiftBeta", SHIFT_BETA },
+ { (char*)"sb", SHIFT_BETA },
+ { (char*)"InterpShiftOne", SHIFT_ONE },
+ { (char*)"ShiftOne", SHIFT_ONE },
+ { (char*)"s1", SHIFT_ONE },
+ { (char*)"InterpShiftZero", SHIFT_ZERO },
+ { (char*)"s0", SHIFT_ZERO },
+ { (char*)"LinearWittenBell", LINEAR_WB },
+ { (char*)"wb", LINEAR_WB },
+ { (char*)"Mixture", MIXTURE },
+ { (char*)"mix", MIXTURE },
+ END_ENUM
+};
+
+
+// Build a mixture LM from a config file: the first line of sublminfo
+// gives the number of sub-LMs; each following line is an argv-style
+// option list (-slm type, -str training file, -sp prune threshold, ...)
+// parsed with DeclareParams/GetParams, from which one sub-LM is
+// constructed. The super-dictionary (and optionally the super n-gram
+// table) is the union of all sub-LM dictionaries.
+// BUGFIXES: (1) par had max_npar slots but the j>max_npar guard allowed
+// writing par[max_npar] — one past the end; allocate max_npar+1 slots.
+// (2) case KNESER_NEY left sublm[i] uninitialized (its assignment is
+// commented out), crashing later on first use; fail loudly instead.
+mixture::mixture(bool fulltable,char* sublminfo,int depth,int prunefreq,char* ipfile,char* opfile):
+ mdiadaptlm((char *)NULL,depth)
+ {
+
+ prunethresh=prunefreq;
+ ipfname=ipfile;
+ opfname=opfile;
+ usefulltable=fulltable;
+
+ mfstream inp(sublminfo,ios::in );
+ if (!inp) {
+ std::stringstream ss_msg;
+ ss_msg << "cannot open " << sublminfo;
+ exit_error(IRSTLM_ERROR_IO, ss_msg.str());
+ }
+
+ char line[MAX_LINE];
+ inp.getline(line,MAX_LINE);
+
+ sscanf(line,"%d",&numslm);
+
+ sublm=new interplm* [numslm];
+
+ cerr << "WARNING: Parameters PruneSingletons (ps) and PruneTopSingletons (pts) are not taken into account for this type of LM (mixture); please specify the singleton pruning policy for each submodel using parameters \"-sps\" and \"-spts\" in the configuraton file\n";
+
+ int max_npar=6;
+ for (int i=0; i<numslm; i++) {
+ // argv-style vector: slot 0 is a dummy program name; tokens fill
+ // slots 1..max_npar, hence max_npar+1 slots are needed
+ char **par=new char*[max_npar+1];
+ par[0]=new char[BUFSIZ];
+ par[0][0]='\0';
+
+ inp.getline(line,MAX_LINE);
+
+ const char *const wordSeparators = " \t\r\n";
+ char *word = strtok(line, wordSeparators);
+ int j = 1;
+
+ while (word){
+ if (j>max_npar){
+ std::stringstream ss_msg;
+ ss_msg << "Too many parameters (expected " << max_npar << ")";
+ exit_error(IRSTLM_ERROR_DATA, ss_msg.str());
+ }
+ par[j] = new char[MAX_LINE];
+ strcpy(par[j],word);
+ // std::cerr << "par[j]:|" << par[j] << "|" << std::endl;
+ word = strtok(0, wordSeparators);
+ j++;
+ }
+
+ int actual_npar = j;
+
+ char *subtrainfile;
+ int slmtype;
+ bool subprunesingletons;
+ bool subprunetopsingletons;
+ char *subprune_thr_str=NULL;
+
+ int subprunefreq;
+
+ // register option names and destinations for this line's parse
+ DeclareParams((char*)
+ "SubLanguageModelType",CMDENUMTYPE|CMDMSG, &slmtype, SLmTypeEnum, "type of the sub LM",
+ "slm",CMDENUMTYPE|CMDMSG, &slmtype, SLmTypeEnum, "type of the sub LM",
+ "sTrainOn",CMDSTRINGTYPE|CMDMSG, &subtrainfile, "training file of the sub LM",
+ "str",CMDSTRINGTYPE|CMDMSG, &subtrainfile, "training file of the sub LM",
+ "sPruneThresh",CMDSUBRANGETYPE|CMDMSG, &subprunefreq, 0, 1000, "threshold for pruning the sub LM",
+ "sp",CMDSUBRANGETYPE|CMDMSG, &subprunefreq, 0, 1000, "threshold for pruning the sub LM",
+ "sPruneSingletons",CMDBOOLTYPE|CMDMSG, &subprunesingletons, "boolean flag for pruning of singletons of the sub LM (default is true)",
+ "sps",CMDBOOLTYPE|CMDMSG, &subprunesingletons, "boolean flag for pruning of singletons of the sub LM (default is true)",
+ "sPruneTopSingletons",CMDBOOLTYPE|CMDMSG, &subprunetopsingletons, "boolean flag for pruning of singletons at the top level of the sub LM (default is false)",
+ "spts",CMDBOOLTYPE|CMDMSG, &subprunetopsingletons, "boolean flag for pruning of singletons at the top level of the sub LM (default is false)",
+ "sPruneFrequencyThreshold",CMDSTRINGTYPE|CMDMSG, &subprune_thr_str, "pruning frequency threshold for each level of the sub LM; comma-separated list of values; (default is \"0 0 ... 0\", for all levels)",
+ "spft",CMDSTRINGTYPE|CMDMSG, &subprune_thr_str, "pruning frequency threshold for each level of the sub LM; comma-separated list of values; (default is \"0 0 ... 0\", for all levels)",
+ (char *)NULL );
+
+ // defaults, set after DeclareParams but before parsing
+ subtrainfile=NULL;
+ slmtype=0;
+ subprunefreq=0;
+ subprunesingletons=true;
+ subprunetopsingletons=false;
+
+ GetParams(&actual_npar, &par, (char*) NULL);
+
+ // NOTE(review): par and its strings are never freed (GetParams may
+ // retain pointers into them, so freeing here would be unsafe
+ // without auditing cmd.c) — leaked once per sub-LM line.
+
+ if (!slmtype) {
+ std::stringstream ss_msg;
+ ss_msg << "The type (-slm) for sub LM number " << i+1 << " is not specified" ;
+ exit_error(IRSTLM_ERROR_DATA, ss_msg.str());
+ }
+
+ if (!subtrainfile) {
+ std::stringstream ss_msg;
+ ss_msg << "The file (-str) for sub lm number " << i+1 << " is not specified";
+ exit_error(IRSTLM_ERROR_DATA, ss_msg.str());
+ }
+
+ if (subprunefreq==-1) {
+ std::stringstream ss_msg;
+ ss_msg << "The prune threshold (-sp) for sub lm number " << i+1 << " is not specified";
+ exit_error(IRSTLM_ERROR_DATA, ss_msg.str());
+ }
+
+ switch (slmtype) {
+
+ case LINEAR_WB:
+ sublm[i]=new linearwb(subtrainfile,depth,subprunefreq,IMPROVEDSHIFTBETA_I);
+ break;
+
+ case SHIFT_BETA:
+ sublm[i]=new shiftbeta(subtrainfile,depth,subprunefreq,-1,SHIFTBETA_I);
+ break;
+
+ case KNESER_NEY:
+ // kneserney is not implemented; fail loudly instead of leaving
+ // sublm[i] uninitialized (previously crashed on first use)
+ // lm=new kneserney(subtrainfile,depth,subprunefreq,-1,KNESERNEY_I);
+ exit_error(IRSTLM_ERROR_DATA, "KneserNey (kn) smoothing is not implemented; use ImprovedKneserNey (ikn)");
+ break;
+
+ case MOD_SHIFT_BETA:
+ case IMPROVED_KNESER_NEY:
+ sublm[i]=new improvedkneserney(subtrainfile,depth,subprunefreq,IMPROVEDKNESERNEY_I);
+ break;
+
+ case IMPROVED_SHIFT_BETA:
+ sublm[i]=new improvedshiftbeta(subtrainfile,depth,subprunefreq,IMPROVEDSHIFTBETA_I);
+ break;
+
+ case SHIFT_ONE:
+ sublm[i]=new shiftone(subtrainfile,depth,subprunefreq,SIMPLE_I);
+ break;
+
+ case MIXTURE:
+ sublm[i]=new mixture(usefulltable,subtrainfile,depth,subprunefreq);
+ break;
+
+ default:
+ exit_error(IRSTLM_ERROR_DATA, "not implemented yet");
+ };
+
+ sublm[i]->prunesingletons(subprunesingletons==true);
+ sublm[i]->prunetopsingletons(subprunetopsingletons==true);
+
+ if (subprunetopsingletons==true)
+ //apply most specific pruning method
+ sublm[i]->prunesingletons(false);
+
+ if (subprune_thr_str)
+ sublm[i]->set_prune_ngram(subprune_thr_str);
+
+
+ cerr << "eventually generate OOV code of sub lm[" << i << "]\n";
+ sublm[i]->dict->genoovcode();
+
+ //create super dictionary
+ dict->augment(sublm[i]->dict);
+
+ //creates the super n-gram table
+ if(usefulltable) augment(sublm[i]);
+
+ cerr << "super table statistics\n";
+ stat(2);
+ }
+
+ cerr << "eventually generate OOV code of the mixture\n";
+ dict->genoovcode();
+ cerr << "dict size of the mixture:" << dict->size() << "\n";
+ //tying parameters
+ k1=2;
+ k2=10;
+ };
+
+// Relative Euclidean distance between parameter vectors l1 and l2 of
+// length n, normalized by |l1|: sqrt(sum((l1-l2)^2) / sum(l1^2)).
+// NOTE(review): divides by sum(l1^2) — zero when l1 is all zeros;
+// confirm callers never pass a zero vector.
+double mixture::reldist(double *l1,double *l2,int n)
+{
+ double dist=0.0,size=0.0;
+ for (int i=0; i<n; i++) {
+ dist+=(l1[i]-l2[i])*(l1[i]-l2[i]);
+ size+=l1[i]*l1[i];
+ }
+ return sqrt(dist/size);
+}
+
+
+// Pseudo-random double in [0,1] from rand() (seeded in train()).
+double rand01()
+{
+ return (double)rand()/(double)RAND_MAX;
+}
+
+// Build the word -> interpolation-parameter-class map pm[] over the
+// first sub-LM's dictionary, tying words by frequency: freq<=k1 share
+// class 0, k1<freq<=k2 get classes 1..k2-k1, and each word with
+// freq>k2 gets its own class. Updates pmax to the class count.
+int mixture::genpmap()
+{
+ dictionary* d=sublm[0]->dict;
+
+ cerr << "Computing parameters mapping: ..." << d->size() << " ";
+ pm=new int[d->size()];
+ //initialize
+ for (int i=0; i<d->size(); i++) pm[i]=0;
+
+ pmax=k2-k1+1; //update # of parameters
+
+ for (int w=0; w<d->size(); w++) {
+ int f=d->freq(w);
+ if ((f>k1) && (f<=k2)) pm[w]=f-k1;
+ else if (f>k2) {
+ pm[w]=pmax++;
+ }
+ }
+ cerr << "pmax " << pmax << " ";
+ return 1;
+}
+
+// Parameter class for n-gram ng at level lev: class of the last
+// history word (translated into sub-LM 0's dictionary), or 0 for
+// unigrams / unknown histories.
+int mixture::pmap(ngram ng,int lev)
+{
+
+ ngram h(sublm[0]->dict);
+ h.trans(ng);
+
+ if (lev<=1) return 0;
+ //get the last word of history
+ if (!sublm[0]->get(h,2,1)) return 0;
+ return (int) pm[*h.wordp(2)];
+}
+
+
+// Save interpolation parameters to opf: a text header "<order> <pmax>"
+// followed by the l[level][class][submodel] weights written with the
+// byte-order-safe writex (see loadpar for the reader).
+int mixture::savepar(char* opf)
+{
+ mfstream out(opf,ios::out);
+
+ cerr << "saving parameters in " << opf << "\n";
+ out << lmsize() << " " << pmax << "\n";
+
+ for (int i=0; i<=lmsize(); i++)
+ for (int j=0; j<pmax; j++)
+ out.writex(l[i][j],sizeof(double),numslm);
+
+
+ return 1;
+}
+
+
+// Load interpolation parameters written by savepar(); the header must
+// match this model's order and pmax or the load aborts.
+int mixture::loadpar(char* ipf)
+{
+
+ mfstream inp(ipf,ios::in);
+
+ if (!inp) {
+ std::stringstream ss_msg;
+ ss_msg << "cannot open file: " << ipf;
+ exit_error(IRSTLM_ERROR_IO, ss_msg.str());
+ }
+
+ cerr << "loading parameters from " << ipf << "\n";
+
+ // check compatibility
+ char header[100];
+ inp.getline(header,100);
+ int value1,value2;
+ sscanf(header,"%d %d",&value1,&value2);
+
+ if (value1 != lmsize() || value2 != pmax) {
+ std::stringstream ss_msg;
+ ss_msg << "parameter file " << ipf << " is incompatible";
+ exit_error(IRSTLM_ERROR_DATA, ss_msg.str());
+ }
+
+ for (int i=0; i<=lmsize(); i++)
+ for (int j=0; j<pmax; j++)
+ inp.readx(l[i][j],sizeof(double),numslm);
+
+ return 1;
+}
+
+// Train the mixture: train every sub-LM, then estimate the
+// interpolation weights l[level][class][submodel] by an EM-style
+// leaving-one-out loop over sub-LM 0's n-grams (at most 20 iterations
+// per level, stopping a class when its weights move by <=5% relative
+// distance). Weights are loaded from ipfname instead when given, and
+// saved to opfname when given.
+int mixture::train()
+{
+
+ double zf;
+
+ srand(1333);
+
+ genpmap();
+
+ if (dub()<dict->size()) {
+ std::stringstream ss_msg;
+ ss_msg << "\nERROR: DUB value is too small: the LM will possibly compute wrong probabilities if sub-LMs have different vocabularies!\n";
+ ss_msg << "This exception should already have been handled before!!!\n";
+ exit_error(IRSTLM_ERROR_MODEL, ss_msg.str());
+ }
+
+ cerr << "mixlm --> DUB: " << dub() << endl;
+ for (int i=0; i<numslm; i++) {
+ cerr << i << " sublm --> DUB: " << sublm[i]->dub() << endl;
+ cerr << "eventually generate OOV code ";
+ cerr << sublm[i]->dict->encode(sublm[i]->dict->OOV()) << "\n";
+ sublm[i]->train();
+ }
+
+ //initialize parameters uniformly over the submodels
+
+ for (int i=0; i<=lmsize(); i++) {
+ l[i]=new double*[pmax];
+ for (int j=0; j<pmax; j++) {
+ l[i][j]=new double[numslm];
+ for (int k=0; k<numslm; k++)
+ l[i][j][k]=1.0/(double)numslm;
+ }
+ }
+
+ if (ipfname) {
+ //load parameters from file
+ loadpar(ipfname);
+ } else {
+ //start training of mixture model
+
+ // NOTE(review): runtime-sized stack arrays (VLAs) — a GCC/Clang
+ // extension, not standard C++; large pmax may overflow the stack
+ double oldl[pmax][numslm];
+ char alive[pmax],used[pmax];
+ int totalive;
+
+ ngram ng(sublm[0]->dict);
+
+ for (int lev=1; lev<=lmsize(); lev++) {
+
+ zf=sublm[0]->zerofreq(lev);
+
+ cerr << "Starting training at lev:" << lev << "\n";
+
+ for (int i=0; i<pmax; i++) {
+ alive[i]=1;
+ used[i]=0;
+ }
+ totalive=1;
+ int iter=0;
+ while (totalive && (iter < 20) ) {
+
+ iter++;
+
+ // snapshot current weights and reset accumulators
+ for (int i=0; i<pmax; i++)
+ if (alive[i])
+ for (int j=0; j<numslm; j++) {
+ oldl[i][j]=l[lev][i][j];
+ l[lev][i][j]=1.0/(double)numslm;
+ }
+
+ sublm[0]->scan(ng,INIT,lev);
+ while(sublm[0]->scan(ng,CONT,lev)) {
+
+ //do not include oov for unigrams
+ if ((lev==1) && (*ng.wordp(1)==sublm[0]->dict->oovcode()))
+ continue;
+
+ int par=pmap(ng,lev);
+ used[par]=1;
+
+ // check whether this parameter class still needs updating
+ if (alive[par]) {
+
+ double backoff=(lev>1?prob(ng,lev-1):1); //backoff
+ double denom=0.0;
+ double* numer = new double[numslm];
+ double fstar,lambda;
+
+ //int cv=(int)floor(zf * (double)ng.freq + rand01());
+ //int cv=1; //old version of leaving-one-out
+ int cv=(int)floor(zf * (double)ng.freq)+1;
+ //int cv=1; //old version of leaving-one-out
+ //if (lev==3)q
+
+ //if (iter>10)
+ // cout << ng
+ // << " backoff " << backoff
+ // << " level " << lev
+ // << "\n";
+
+ for (int i=0; i<numslm; i++) {
+
+ //use cv if i=0
+
+ sublm[i]->discount(ng,lev,fstar,lambda,(i==0)*(cv));
+ numer[i]=oldl[par][i]*(fstar + lambda * backoff);
+
+ ngram ngslm(sublm[i]->dict);
+ ngslm.trans(ng);
+ // spread OOV mass over the unseen part of the vocabulary
+ if ((*ngslm.wordp(1)==sublm[i]->dict->oovcode()) &&
+ (dict->dub() > sublm[i]->dict->size()))
+ numer[i]/=(double)(dict->dub() - sublm[i]->dict->size());
+
+ denom+=numer[i];
+ }
+
+ for (int i=0; i<numslm; i++) {
+ l[lev][par][i]+=(ng.freq * (numer[i]/denom));
+ //if (iter>10)
+ //cout << ng << " l: " << l[lev][par][i] << "\n";
+ }
+ delete []numer;
+ }
+ }
+
+ //normalize all parameters
+ totalive=0;
+ for (int i=0; i<pmax; i++) {
+ double tot=0;
+ if (alive[i]) {
+ for (int j=0; j<numslm; j++) tot+=(l[lev][i][j]);
+ for (int j=0; j<numslm; j++) l[lev][i][j]/=tot;
+
+ //decide if to continue to update
+ if (!used[i] || (reldist(l[lev][i],oldl[i],numslm)<=0.05))
+ alive[i]=0;
+ }
+ totalive+=alive[i];
+ }
+
+ cerr << "Lev " << lev << " iter " << iter << " tot alive " << totalive << "\n";
+
+ }
+ }
+ }
+
+ if (opfname) savepar(opfname);
+
+
+ return 1;
+}
+
+// Mixture discounting: fstar/lambda are the weighted sums of each
+// sub-LM's discounted estimates, using the trained weights for this
+// n-gram's parameter class. OOV mass is divided by each sub-LM's
+// unseen-vocabulary size on the way in and multiplied back by the
+// mixture's own on the way out.
+// NOTE(review): the inner guard tests dict->dub() but the division
+// uses sublm[i]->dict->dub() — confirm these are kept in sync (cf.
+// train(), which uses dict->dub() for both).
+int mixture::discount(ngram ng_,int size,double& fstar,double& lambda,int /* unused parameter: cv */)
+{
+
+ ngram ng(dict);
+ ng.trans(ng_);
+
+ double lambda2,fstar2;
+ fstar=0.0;
+ lambda=0.0;
+ int p=pmap(ng,size);
+ MY_ASSERT(p <= pmax);
+ double lsum=0;
+
+
+ for (int i=0; i<numslm; i++) {
+ sublm[i]->discount(ng,size,fstar2,lambda2,0);
+
+ ngram ngslm(sublm[i]->dict);
+ ngslm.trans(ng);
+
+ if (dict->dub() > sublm[i]->dict->size()){
+ if (*ngslm.wordp(1) == sublm[i]->dict->oovcode()) {
+ fstar2/=(double)(sublm[i]->dict->dub() - sublm[i]->dict->size()+1);
+ }
+ }
+
+
+ fstar+=(l[size][p][i]*fstar2);
+ lambda+=(l[size][p][i]*lambda2);
+ lsum+=l[size][p][i];
+ }
+
+ if (dict->dub() > dict->size())
+ if (*ng.wordp(1) == dict->oovcode()) {
+ fstar*=(double)(dict->dub() - dict->size()+1);
+ }
+
+ // the class weights must sum to 1 (within double precision)
+ MY_ASSERT((lsum>LOWER_DOUBLE_PRECISION_OF_1) && (lsum<=UPPER_DOUBLE_PRECISION_OF_1));
+ return 1;
+}
+
+
+// Look up an n-gram. With usefulltable the shared super-table already
+// holds everything; otherwise the table is rebuilt on demand for the
+// queried word's context by merging the matching subtrees of every
+// sub-LM (translated into the mixture dictionary) before delegating
+// to ngramtable::get.
+int mixture::get(ngram& ng,int n,int lev)
+{
+
+ if (usefulltable)
+ {
+ return ngramtable::get(ng,n,lev);
+ }
+
+ //free current tree
+ resetngramtable();
+
+ //get 1-word prefix from ng
+ ngram ug(dict,1);
+ *ug.wordp(1)=*ng.wordp(ng.size);
+
+ //local ngram to upload entries
+ ngram locng(dict,maxlevel());
+
+ //allocate subtrees from sublm
+ for (int i=0; i<numslm; i++) {
+
+ ngram subug(sublm[i]->dict,1);
+ subug.trans(ug);
+
+ if (sublm[i]->get(subug,1,1)) {
+
+ ngram subng(sublm[i]->dict,maxlevel());
+ *subng.wordp(maxlevel())=*subug.wordp(1);
+ sublm[i]->scan(subug.link,subug.info,1,subng,INIT,maxlevel());
+ while(sublm[i]->scan(subug.link,subug.info,1,subng,CONT,maxlevel())) {
+ locng.trans(subng);
+ put(locng);
+ }
+ }
+ }
+
+ return ngramtable::get(ng,n,lev);
+
+}
+}//namespace irstlm
+
+
+
+
+
+
+
diff --git a/src/mixture.h b/src/mixture.h
new file mode 100644
index 0000000..d9774ee
--- /dev/null
+++ b/src/mixture.h
@@ -0,0 +1,95 @@
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+
+// Mixture of linear interpolation LMs
+
+#ifndef LM_MIXTURE
+#define LM_MIXTURE
+
+namespace irstlm {
+
+// Mixture of interpolated LMs: combines several interplm sub-models with
+// context-class-dependent interpolation weights (estimated by train(),
+// optionally loaded/saved through ipfname/opfname).
+class mixture: public mdiadaptlm
+{
+ double** l[MAX_NGRAM]; //interpolation parameters
+ int* pm; //parameter mappings
+ int pmax; //#parameters
+ int k1,k2; //two thresholds
+ int numslm; // number of sub language models
+ int prunethresh; // pruning frequency threshold
+ interplm** sublm; // array of sub language models
+ char *ipfname; // input parameter file (NULL if none)
+ char *opfname; // output parameter file (NULL if none)
+
+
+ double reldist(double *l1,double *l2,int n);
+ int genpmap();
+ int pmap(ngram ng,int lev);
+public:
+
+ bool usefulltable;
+
+ mixture(bool fulltable,char *sublminfo,int depth,int prunefreq=0,char* ipfile=NULL,char* opfile=NULL);
+
+ int train();
+
+ int savepar(char* opf);
+ int loadpar(char* opf);
+
+ // dictionary upper bound of the mixture
+ inline int dub() {
+ return dict->dub();
+ }
+
+ // sets the dictionary upper bound on the mixture and all sub-LMs
+ inline int dub(int value) {
+ for (int i=0; i<numslm; i++) {
+ sublm[i]->dub(value);
+ }
+ return dict->dub(value);
+ }
+
+ // sets the two frequency thresholds used for parameter tying
+ void settying(int a,int b) {
+ k1=a;
+ k2=b;
+ }
+ int discount(ngram ng,int size,double& fstar,double& lambda,int cv=0);
+
+
+
+ ~mixture(){
+
+ // NOTE(review): frees l[0..lmsize()]; confirm train() allocates the
+ // level-0 entry too, otherwise l[0] is freed uninitialized. The sublm
+ // array itself is never freed (minor leak).
+ for (int i=0;i<=lmsize();i++){
+ for (int j=0; j<pmax; j++) free(l[i][j]);
+ free(l[i]);
+ }
+
+ for (int i=0;i<numslm;i++) delete(sublm[i]);
+
+ }
+
+ //this extension builds a common ngramtable on demand
+ int get(ngram& ng,int n,int lev);
+
+};
+
+}//namespace irstlm
+#endif
+
+
+
+
diff --git a/src/n_gram.cpp b/src/n_gram.cpp
new file mode 100644
index 0000000..276e6bc
--- /dev/null
+++ b/src/n_gram.cpp
@@ -0,0 +1,299 @@
+// $Id: n_gram.cpp 3461 2010-08-27 10:17:34Z bertoldi $
+
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+
+
+#include <stdio.h>
+#include <cstdlib>
+#include <stdlib.h>
+#include <iomanip>
+#include <sstream>
+#include "util.h"
+#include "mempool.h"
+#include "htable.h"
+#include "dictionary.h"
+#include "n_gram.h"
+#include "index.h"
+
+using namespace std;
+
+// Builds an empty n-gram bound to dictionary d with initial length sz.
+// Word codes, scan pointers (midx) and the trie path are zeroed; the
+// tree-traversal state (link/info/succ) is reset.
+ngram::ngram(dictionary* d,int sz)
+{
+ dict=d;
+ size=sz;
+ succ=0;
+ freq=0;
+ info=0;
+ pinfo=0;
+ link=NULL;
+ isym=-1;
+ memset(word,0,sizeof(int)*MAX_NGRAM);
+ memset(midx,0,sizeof(int)*MAX_NGRAM);
+ memset(path,0,sizeof(char *)*MAX_NGRAM);
+}
+
+// Copy constructor: clones the encoded word sequence and scan pointers of
+// ng, while resetting the tree-traversal fields (succ/info/link), which
+// refer to a position inside a specific ngramtable.
+ngram::ngram(ngram& ng)
+{
+ size=ng.size;
+ freq=ng.freq;
+ succ=0;
+ info=0;
+ pinfo=0;
+ link=NULL;
+ isym=-1;
+ dict=ng.dict;
+ memcpy(word,ng.word,sizeof(int)*MAX_NGRAM);
+ // BUG FIX: the scan pointers must be copied from ng.midx, not ng.word
+ // (the original copied the encoded words into midx).
+ memcpy(midx,ng.midx,sizeof(int)*MAX_NGRAM);
+ // initialize the trie path like the main constructor does, instead of
+ // leaving it uninitialized
+ memset(path,0,sizeof(char *)*MAX_NGRAM);
+}
+
+
+// Returns 1 if any of the lev most recent words of the ngram equals the
+// dictionary code of s, 0 otherwise (also 0 when s is not in the
+// dictionary).
+int ngram::containsWord(const char* s,int lev) {
+
+ int c=dict->encode(s);
+ if (c == -1) return 0;
+
+ MY_ASSERT(lev <= size);
+ for (int i=0; i<lev; i++) {
+ if (*wordp(size-i)== c) return 1;
+ }
+ return 0;
+}
+
+// Copies ng into this ngram. If both share the same dictionary the copy
+// is bitwise (including the table-scan state); otherwise every word is
+// decoded with ng's dictionary and re-encoded with this one, and the
+// scan state (info/midx/isym) is reset.
+void ngram::trans (const ngram& ng)
+{
+ size=ng.size;
+ freq=ng.freq;
+ if (dict == ng.dict) {
+ info=ng.info;
+ isym=ng.isym;
+ memcpy(word,ng.word,sizeof(int)*MAX_NGRAM);
+ memcpy(midx,ng.midx,sizeof(int)*MAX_NGRAM);
+ } else {
+ info=0;
+ memset(midx,0,sizeof(int)*MAX_NGRAM);
+ isym=-1;
+ for (int i=1; i<=size; ++i)
+ word[MAX_NGRAM-i]=dict->encode(ng.dict->decode(*ng.wordp(i)));
+ }
+}
+
+
+// Stores the words of ng into this ngram in reverse order
+// (same size; dictionaries are assumed compatible by the caller).
+void ngram::invert (const ngram& ng)
+{
+ size=ng.size;
+ for (int i=1; i<=size; i++) {
+ *wordp(i)=*ng.wordp(size-i+1);
+ }
+}
+
+// Drops the oldest word: the remaining size-1 codes are moved one slot
+// towards the end of the right-aligned buffer and size is decremented.
+void ngram::shift ()
+{
+ memmove((void *)&word[MAX_NGRAM-size+1],(void *)&word[MAX_NGRAM-size],(size-1) * sizeof(int));
+ size--;
+}
+
+// Drops the sz oldest words (capped at the current size).
+void ngram::shift (int sz)
+{
+ if (sz>size) sz=size;
+ memmove((void *)&word[MAX_NGRAM-size+sz],(void *)&word[MAX_NGRAM-size],(size-sz) * sizeof(int));
+ size-=sz;
+}
+
+
+// Reads one whitespace-delimited word from fi, encodes it with the
+// ngram's dictionary and appends it as the most recent word, shifting
+// the older words one position back. An OOV word is a fatal error.
+ifstream& operator>> ( ifstream& fi , ngram& ng)
+{
+ char w[MAX_WORD];
+ memset(w,0,MAX_WORD);
+ w[0]='\0';
+
+ if (!(fi >> setw(MAX_WORD) >> w))
+ return fi;
+
+ if (strlen(w)==(MAX_WORD-1))
+ cerr << "ngram: a too long word was read ("
+ << w << ")\n";
+
+ int c=ng.dict->encode(w);
+
+ if (c == -1 ) {
+ std::stringstream ss_msg;
+ ss_msg << "ngram: " << w << " is OOV";
+ exit_error(IRSTLM_ERROR_MODEL, ss_msg.str());
+ }
+
+ // BUG FIX: source and destination overlap, so memcpy is undefined
+ // behavior here; memmove handles overlapping regions correctly.
+ memmove(ng.word,ng.word+1,(MAX_NGRAM-1)*sizeof(int));
+
+ ng.word[MAX_NGRAM-1]=(int)c;
+ ng.freq=1;
+
+ if (ng.size<MAX_NGRAM) ng.size++;
+
+ return fi;
+
+}
+
+
+// Encodes w with the dictionary and appends it via pushc(); exits the
+// process if w is OOV. Returns 1.
+int ngram::pushw(const char* w)
+{
+
+ MY_ASSERT(dict!=NULL);
+
+ int c=dict->encode(w);
+
+ if (c == -1 ) {
+ cerr << "ngram: " << w << " is OOV \n";
+ exit(1);
+ }
+
+ pushc(c);
+
+ return 1;
+
+}
+
+// Appends code c as the most recent word, shifting the older words one
+// slot back; once size reaches MAX_NGRAM the oldest word is dropped.
+// Returns 1.
+int ngram::pushc(int c)
+{
+
+ size++;
+ if (size>MAX_NGRAM) size=MAX_NGRAM;
+ size_t len = size - 1; //i.e. if size==MAX_NGRAM, the farthest position is lost
+ size_t src = MAX_NGRAM - len;
+
+ memmove((void *)&word[src - 1],(void *)&word[src], len * sizeof(int));
+
+ word[MAX_NGRAM-1]=c; // fill the most recent position
+
+ return 1;
+
+}
+
+// Appends the first codes_len codes as the most recent words (codes[0]
+// becomes the oldest of the appended block); older words are shifted
+// back and dropped beyond MAX_NGRAM. Returns 1.
+int ngram::pushc(int* codes, int codes_len)
+{
+ //copy the first codes_len elements from codes into the actual ngram; sz must be smaller than MAX_NGRAM
+ //shift codes_len elements of the ngram backwards
+ MY_ASSERT (codes_len <= MAX_NGRAM);
+
+ size+=codes_len;
+
+ if (size>MAX_NGRAM) size=MAX_NGRAM;
+ size_t len = size - codes_len; // surviving old words
+ size_t src = MAX_NGRAM - len;
+
+ if (len > 0) memmove((void *)&word[src - codes_len],(void *)&word[src], len * sizeof(int));
+ memcpy((void *)&word[MAX_NGRAM - codes_len],(void*)&codes[0],codes_len*sizeof(int));
+
+ return 1;
+}
+
+// Checks the history part of the ngram: returns 0 if any of the words in
+// positions 2..sz (i.e. excluding the most recent word) is the OOV code,
+// 1 otherwise.
+int ngram::ckhisto(int sz) {
+
+ for (int i=sz; i>1; i--)
+ if (*wordp(i)==dict->oovcode())
+ return 0;
+ return 1;
+}
+
+
+
+// Value equality: same length, same dictionary object, and identical
+// word codes over the active positions.
+bool ngram::operator==(const ngram &compare) const {
+ if ( size != compare.size || dict != compare.dict)
+ return false;
+ else
+ for (int i=size; i>0; i--)
+ if (word[MAX_NGRAM-i] != compare.word[MAX_NGRAM-i])
+ return false;
+ return true;
+}
+
+// Logical negation of operator== (kept as an explicit loop for symmetry).
+bool ngram::operator!=(const ngram &compare) const {
+ if ( size != compare.size || dict != compare.dict)
+ return true;
+ else
+ for (int i=size; i>0; i--)
+ if (word[MAX_NGRAM-i] != compare.word[MAX_NGRAM-i])
+ return true;
+ return false;
+}
+
+// Reads one whitespace-delimited word from a generic input stream and
+// appends it through pushw() (which exits the process on OOV); freq is
+// reset to 1.
+istream& operator>> ( istream& fi , ngram& ng)
+{
+ char w[MAX_WORD];
+ memset(w,0,MAX_WORD);
+ w[0]='\0';
+
+ MY_ASSERT(ng.dict != NULL);
+
+ if (!(fi >> setw(MAX_WORD) >> w))
+ return fi;
+
+ if (strlen(w)==(MAX_WORD-1))
+ cerr << "ngram: a too long word was read ("
+ << w << ")\n";
+
+ ng.pushw(w);
+
+ ng.freq=1;
+
+ return fi;
+
+}
+
+
+
+
+// Writes the ngram words oldest-first, space separated, followed by a
+// tab and the frequency.
+ofstream& operator<< (ofstream& fo,ngram& ng)
+{
+
+ MY_ASSERT(ng.dict != NULL);
+
+ for (int i=ng.size; i>0; i--)
+ fo << ng.dict->decode(ng.word[MAX_NGRAM-i]) << (i>1?" ":"");
+ fo << "\t" << ng.freq;
+ return fo;
+}
+
+// Generic-stream variant of the ofstream printer: words oldest-first,
+// then a tab and the frequency.
+ostream& operator<< (ostream& fo,ngram& ng)
+{
+
+ MY_ASSERT(ng.dict != NULL);
+
+ for (int i=ng.size; i>0; i--)
+ fo << ng.dict->decode(ng.word[MAX_NGRAM-i]) << (i>1?" ":"");
+ fo << "\t" << ng.freq;
+
+ return fo;
+}
+
+/*
+ main(int argc, char** argv){
+ dictionary d(argv[1]);
+ ifstream txt(argv[1]);
+ ngram ng(&d);
+
+ while (txt >> ng){
+ std::cout << ng << "\n";
+ }
+
+ ngram ng2=ng;
+ cerr << "copy last =" << ng << "\n";
+ }
+ */
+
diff --git a/src/n_gram.h b/src/n_gram.h
new file mode 100644
index 0000000..ccfa85c
--- /dev/null
+++ b/src/n_gram.h
@@ -0,0 +1,129 @@
+// $Id: n_gram.h 3461 2010-08-27 10:17:34Z bertoldi $
+
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+
+// n-gram tables
+// by M. Federico
+// Copyright Marcello Federico, ITC-irst, 1998
+
+#ifndef MF_NGRAM_H
+#define MF_NGRAM_H
+
+#include <fstream>
+#include "util.h"
+#include "dictionary.h"
+
+#ifndef MYMAXNGRAM
+#define MYMAXNGRAM 20
+#endif
+#define MAX_NGRAM MYMAXNGRAM
+
+class dictionary;
+
+//typedef int code;
+
+// An n-gram of dictionary codes plus the traversal state needed to walk
+// an ngramtable trie. Words are stored right-aligned in word[]: the most
+// recent word sits at word[MAX_NGRAM-1] and wordp(k) returns the k-th
+// most recent word (wordp(1) = last word, wordp(size) = oldest).
+class ngram
+{
+ int word[MAX_NGRAM]; //encoded ngram
+public:
+ dictionary *dict; // dictionary
+ char* link; // ngram-tree pointer
+ char* succlink; // pointer to the first successor
+ int midx[MAX_NGRAM]; // ngram-tree scan pointer
+ char* path[MAX_NGRAM]; // path in the ngram-trie
+ float bowv[MAX_NGRAM]; // vector of bow found in the trie
+
+ int lev; // ngram-tree level
+ int size; // ngram size
+ long long freq; // ngram frequency or integer prob
+ int succ; // number of successors
+ float bow; // back-off weight
+ float prob; // probability
+
+ unsigned char info; // ngram-tree info flags
+ unsigned char pinfo; // ngram-tree parent info flags
+ int isym; // last interruption symbol
+
+ ngram(dictionary* d,int sz=0);
+ ngram(ngram& ng);
+
+ // pointer to the oldest active word
+ inline int *wordp() { // n-gram pointer
+ return wordp(size);
+ }
+ // pointer to the k-th most recent word, or 0 if k exceeds size
+ inline int *wordp(int k) { // n-gram pointer
+ return size>=k?&word[MAX_NGRAM-k]:0;
+ }
+ inline const int *wordp() const { // n-gram pointer
+ return wordp(size);
+ }
+ inline const int *wordp(int k) const { // n-gram pointer
+ return size>=k?&word[MAX_NGRAM-k]:0;
+ }
+
+ int containsWord(const char* s,int lev);
+
+ void trans(const ngram& ng);
+ void invert (const ngram& ng);
+ void shift ();
+ void shift (int sz);
+
+ friend std::ifstream& operator>> (std::ifstream& fi,ngram& ng);
+ friend std::ofstream& operator<< (std::ofstream& fi,ngram& ng);
+ friend std::istream& operator>> (std::istream& fi,ngram& ng);
+ friend std::ostream& operator<< (std::ostream& fi,ngram& ng);
+
+ bool operator==(const ngram &compare) const;
+ bool operator!=(const ngram &compare) const;
+
+ /* superseded inline definitions, kept for reference (now in n_gram.cpp)
+ friend bool operator==(const ngram &compare) const {
+ if ( size != compare.size || dict != compare.dict)
+ return false;
+ else
+ for (int i=size; i>0; i--)
+ if (word[MAX_NGRAM-i] != compare.word[MAX_NGRAM-i])
+ return false;
+ return true;
+ }
+
+ inline bool operator!=(const ngram &compare) const {
+ if ( size != compare.size || dict != compare.dict)
+ return true;
+ else
+ for (int i=size; i>0; i--)
+ if (word[MAX_NGRAM-i] != compare.word[MAX_NGRAM-i])
+ return true;
+ return false;
+ }
+*/
+ int ckhisto(int sz);
+
+ int pushc(int c);
+ int pushc(int* codes, int sz);
+ int pushw(const char* w);
+
+ //~ngram();
+};
+
+#endif
+
+
+
diff --git a/src/ngramcache.cpp b/src/ngramcache.cpp
new file mode 100644
index 0000000..5f1f6a8
--- /dev/null
+++ b/src/ngramcache.cpp
@@ -0,0 +1,159 @@
+// $Id: ngramcache.cpp 3679 2010-10-13 09:10:01Z bertoldi $
+
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+#include <iostream>
+#include <fstream>
+#include <stdexcept>
+#include <stdio.h>
+#include <stdlib.h>
+#include <cstring>
+#include <sstream>
+#include <string>
+#include "math.h"
+#include "mempool.h"
+#include "htable.h"
+#include "lmtable.h"
+#include "util.h"
+
+#include "ngramcache.h"
+
+using namespace std;
+
+// Debug helper: dumps the ngsize codes of a cache key to stderr.
+void ngramcache::print (const int* ngp)
+{
+ std::cerr << "ngp: size:" << ngsize << "|";
+ for (int i=0; i<ngsize; i++)
+ std::cerr << " " << ngp[i];
+ std::cerr << " |\n";
+}
+
+// Builds a cache mapping n-word keys to a payload of 'size' bytes, with
+// room for at most 'maxentries' entries; lf is the hash-table load
+// factor (non-positive values fall back to NGRAMCACHE_LOAD_FACTOR).
+ngramcache::ngramcache(int n,int size,int maxentries,float lf)
+{
+ if (lf<=0.0) lf=NGRAMCACHE_LOAD_FACTOR;
+ load_factor=lf;
+ ngsize=n;
+ infosize=size;
+ maxn=maxentries;
+ entries=0;
+ ht=new htable<int*>((size_t) (maxn/load_factor), ngsize * sizeof(int)); //decrease the lower load factor to reduce collision
+ mp=new mempool(ngsize * sizeof(int)+infosize,MP_BLOCK_SIZE);
+ accesses=0;
+ hits=0;
+};
+
+// Releases the hash table and the entry pool (entries live in the pool).
+ngramcache::~ngramcache()
+{
+ delete ht;
+ delete mp;
+};
+
+
+//resize cache to specified number of entries
+void ngramcache::reset(int n)
+{
+ //ht->stat();
+ delete ht;
+ delete mp;
+ if (n>0) maxn=n;
+ ht=new htable<int*> ((size_t) (maxn/load_factor), ngsize * sizeof(int)); //decrease the lower load factor to reduce collision
+ mp=new mempool(ngsize * sizeof(int)+infosize,MP_BLOCK_SIZE);
+ entries=0;
+};
+
+// Looks up the ngsize-word key ngp; on a hit copies infosize bytes of
+// payload into info and returns the entry, else returns NULL.
+// NOTE(review): assumes the cache was built with infosize==sizeof(char*)
+// for this overload — the payload is a stored pointer; confirm callers.
+char* ngramcache::get(const int* ngp,char*& info)
+{
+ char* found;
+
+ accesses++;
+ if ((found=(char*) ht->find((int *)ngp))) {
+ memcpy(&info,found+ngsize*sizeof(int),infosize);
+ hits++;
+ }
+
+ return found;
+};
+
+// Same as get(char*&) but for a double payload; assumes the cache was
+// built with infosize==sizeof(double) — TODO confirm at call sites.
+char* ngramcache::get(const int* ngp,double& info)
+{
+ char *found;
+
+ accesses++;
+ if ((found=(char*) ht->find((int *)ngp))) {
+ memcpy(&info,found+ngsize*sizeof(int),infosize);
+ hits++;
+ };
+
+ return found;
+};
+
+// Same as the other getters but for a prob_and_state_t payload; assumes
+// infosize==sizeof(prob_and_state_t) — TODO confirm at call sites.
+char* ngramcache::get(const int* ngp,prob_and_state_t& info)
+{
+ char *found;
+
+ accesses++;
+ if ((found=(char*) ht->find((int *)ngp)))
+ {
+ memcpy(&info,found+ngsize*sizeof(int),infosize);
+ hits++;
+ }
+ return found;
+};
+
+// Inserts key ngp with a pointer payload (the pointer VALUE is stored,
+// not the pointed-to data). Asserts the key is not already present.
+// The caller is expected to check isfull() first. Returns 1.
+int ngramcache::add(const int* ngp,const char*& info)
+{
+ char* entry=mp->allocate();
+ memcpy(entry,(char*) ngp,sizeof(int) * ngsize);
+ memcpy(entry + ngsize * sizeof(int),&info,infosize);
+ char* found=(char*)ht->insert((int *)entry);
+ MY_ASSERT(found == entry); //false if key is already inside
+ entries++;
+ return 1;
+};
+
+// Inserts key ngp with a double payload; asserts the key is new.
+// Returns 1.
+int ngramcache::add(const int* ngp,const double& info)
+{
+ char* entry=mp->allocate();
+ memcpy(entry,(char*) ngp,sizeof(int) * ngsize);
+ memcpy(entry + ngsize * sizeof(int),&info,infosize);
+ char *found=(char*) ht->insert((int *)entry);
+ MY_ASSERT(found == entry); //false if key is already inside
+ entries++;
+ return 1;
+};
+
+// Inserts key ngp with a prob_and_state_t payload (shallow byte copy:
+// the embedded state pointer is copied, not the data it points to);
+// asserts the key is new. Returns 1.
+int ngramcache::add(const int* ngp,const prob_and_state_t& info)
+{
+ char* entry=mp->allocate();
+ memcpy(entry,(char*) ngp,sizeof(int) * ngsize);
+ memcpy(entry + ngsize * sizeof(int),&info,infosize);
+ char *found=(char*) ht->insert((int *)entry);
+ MY_ASSERT(found == entry); //false if key is already inside
+ entries++;
+ return 1;
+};
+
+
+// Prints entry/access/hit counters and memory usage to stdout.
+void ngramcache::stat() const
+{
+ std::cout << "ngramcache stats: entries=" << entries << " acc=" << accesses << " hits=" << hits
+ << " ht.used= " << ht->used() << " mp.used= " << mp->used() << " mp.wasted= " << mp->wasted() << "\n";
+};
+
diff --git a/src/ngramcache.h b/src/ngramcache.h
new file mode 100644
index 0000000..dc47952
--- /dev/null
+++ b/src/ngramcache.h
@@ -0,0 +1,93 @@
+// $Id: ngramcache.h 3679 2010-10-13 09:10:01Z bertoldi $
+
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+
+#ifndef MF_NGRAMCACHE_H
+#define MF_NGRAMCACHE_H
+
+#include "mempool.h"
+#include "htable.h"
+
+#define NGRAMCACHE_t ngramcache
+
+#define NGRAMCACHE_LOAD_FACTOR 0.5
+
+// Cached result of an LM probability query: log-probability plus the
+// back-off bookkeeping needed to reuse/extend the LM state.
+typedef struct PROB_AND_STATE_ENTRY {
+ double logpr; //!< probability value of an ngram
+ char* state; //!< the largest suffix of an n-gram contained in the LM table.
+ unsigned int statesize; //!< LM statesize of an ngram
+ double bow; //!< backoff weight
+ int bol; //!< backoff level
+ bool extendible; //!< flag for extendibility of the ngram
+ PROB_AND_STATE_ENTRY(double lp=0.0, char* st=NULL, unsigned int stsz=0, double bw=0.0, int bl=0, bool ex=false): logpr(lp), state(st), statesize(stsz), bow(bw), bol(bl), extendible(ex) {}; //initializer
+} prob_and_state_t;
+
+// debug printer for a cached entry
+void print(prob_and_state_t* pst, std::ostream& out=std::cout);
+
+// Fixed-capacity cache mapping fixed-length ngram keys (ngsize int
+// codes) to an opaque payload of infosize bytes; entries are allocated
+// from a mempool and indexed by an htable. There is no eviction: the
+// caller checks isfull() and calls reset() to flush.
+class ngramcache
+{
+private:
+
+ static const bool debug=true;
+
+ htable<int*>* ht; // key index
+ mempool *mp; // entry storage (key bytes + payload bytes)
+ int maxn; // maximum number of entries
+ int ngsize; // key length in word codes
+ int infosize; // payload size in bytes
+ int accesses; // lookup counter
+ int hits; // successful-lookup counter
+ int entries; // current number of entries
+ float load_factor; //!< ngramcache loading factor
+ void print(const int*);
+
+public:
+ ngramcache(int n,int size,int maxentries,float lf=NGRAMCACHE_LOAD_FACTOR);
+ ~ngramcache();
+
+ inline int cursize() const {
+ return entries;
+ }
+ inline int maxsize() const {
+ return maxn;
+ }
+ void reset(int n=0);
+ char* get(const int* ngp,char*& info);
+ char* get(const int* ngp,double& info);
+ char* get(const int* ngp,prob_and_state_t& info);
+ int add(const int* ngp,const char*& info);
+ int add(const int* ngp,const double& info);
+ int add(const int* ngp,const prob_and_state_t& info);
+ inline int isfull() const {
+ return (entries >= maxn);
+ }
+ void stat() const;
+ inline void used() const {
+ stat();
+ };
+
+ inline float set_load_factor(float value) {
+ return load_factor=value;
+ }
+};
+
+#endif
+
diff --git a/src/ngramtable.cpp b/src/ngramtable.cpp
new file mode 100644
index 0000000..1779e90
--- /dev/null
+++ b/src/ngramtable.cpp
@@ -0,0 +1,1870 @@
+// $Id: ngramtable.cpp 35 2010-07-19 14:52:11Z nicolabertoldi $
+
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit, compile LM
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+
+#include <sstream>
+#include "util.h"
+#include "mfstream.h"
+#include "math.h"
+#include "mempool.h"
+#include "htable.h"
+#include "dictionary.h"
+#include "n_gram.h"
+#include "ngramtable.h"
+#include "crc.h"
+
+using namespace std;
+
+// Configures the byte layout of ngramtable tree nodes for table type tt
+// and the given word-code size (1..4 bytes): field offsets (WORD, MSUCC,
+// MTAB, FLAGS, SUCC1/2, BOFF, I_FREQ/L_FREQ) and frequency-field counts.
+tabletype::tabletype(TABLETYPE tt,int codesize) {
+
+ if (codesize<=4 && codesize>0)
+ CODESIZE=codesize;
+ else {
+ exit_error(IRSTLM_ERROR_DATA,"ngramtable wrong codesize");
+ }
+
+ // largest code representable in 1..4 (and 6) bytes
+ // NOTE(review): code_range[5] is never initialized — confirm no code
+ // path indexes it.
+ code_range[1]=255;
+ code_range[2]=65535;
+ code_range[3]=16777214;
+ code_range[4]=2147483640;
+ code_range[6]=140737488360000LL; //stay below true limit
+// code_range[6]=281474977000000LL; //stay below true limit
+
+ //information which is useful to initialize
+ //LEAFPROB tables
+ L_FREQ_SIZE=FREQ1;
+
+ // offsets common to all node types
+ WORD_OFFS =0;
+ MSUCC_OFFS =CODESIZE;
+ MTAB_OFFS =MSUCC_OFFS+CODESIZE;
+ FLAGS_OFFS =MTAB_OFFS+PTRSIZE;
+
+ switch (tt) {
+
+ case COUNT:
+ SUCC1_OFFS =0;
+ SUCC2_OFFS =0;
+ BOFF_OFFS =0;
+ I_FREQ_OFFS=FLAGS_OFFS+CHARSIZE;
+ I_FREQ_NUM=1;
+ L_FREQ_NUM=1;
+
+ ttype=tt;
+ break;
+
+ case FULL:
+ case IMPROVEDKNESERNEY_B:
+ case IMPROVEDSHIFTBETA_B:
+ SUCC1_OFFS =FLAGS_OFFS+CHARSIZE;
+ SUCC2_OFFS =SUCC1_OFFS+CODESIZE;
+ BOFF_OFFS =SUCC2_OFFS+CODESIZE;
+ I_FREQ_OFFS=BOFF_OFFS+INTSIZE;
+ L_FREQ_OFFS=CODESIZE;
+ I_FREQ_NUM=2;
+ L_FREQ_NUM=1;
+
+ ttype=tt;
+ break;
+
+ case IMPROVEDKNESERNEY_I:
+ case IMPROVEDSHIFTBETA_I:
+ SUCC1_OFFS =FLAGS_OFFS+CHARSIZE;
+ SUCC2_OFFS =SUCC1_OFFS+CODESIZE;
+ BOFF_OFFS =0;
+ I_FREQ_OFFS=SUCC2_OFFS+CODESIZE;
+ L_FREQ_OFFS=CODESIZE;
+ I_FREQ_NUM=2;
+ L_FREQ_NUM=1;
+
+ ttype=tt;
+ break;
+
+ case SIMPLE_I:
+ SUCC1_OFFS = 0;
+ SUCC2_OFFS = 0;
+ BOFF_OFFS = 0;
+ I_FREQ_OFFS= FLAGS_OFFS+CHARSIZE;
+ L_FREQ_OFFS=CODESIZE;
+ I_FREQ_NUM=1;
+ L_FREQ_NUM=1;
+
+ ttype=tt;
+ break;
+
+ case SIMPLE_B:
+ SUCC1_OFFS = 0;
+ SUCC2_OFFS = 0;
+ BOFF_OFFS = FLAGS_OFFS+CHARSIZE;
+ I_FREQ_OFFS = BOFF_OFFS+INTSIZE;
+ L_FREQ_OFFS = CODESIZE;
+ I_FREQ_NUM = 1;
+ L_FREQ_NUM = 1;
+
+ ttype=tt;
+ break;
+
+ case KNESERNEY_I:
+ case SHIFTBETA_I:
+ SUCC1_OFFS = FLAGS_OFFS+CHARSIZE;
+ SUCC2_OFFS = 0;
+ BOFF_OFFS = 0;
+ I_FREQ_OFFS= SUCC1_OFFS+CODESIZE;
+ L_FREQ_OFFS=CODESIZE;
+ I_FREQ_NUM=1;
+ L_FREQ_NUM=1;
+
+ ttype=tt;
+ break;
+
+ case KNESERNEY_B:
+ case SHIFTBETA_B:
+ SUCC1_OFFS = FLAGS_OFFS+CHARSIZE;
+ SUCC2_OFFS = 0;
+ BOFF_OFFS = SUCC1_OFFS+CODESIZE;
+ I_FREQ_OFFS = BOFF_OFFS+INTSIZE;
+ L_FREQ_OFFS = CODESIZE;
+ I_FREQ_NUM = 1;
+ L_FREQ_NUM = 1;
+
+ ttype=tt;
+ break;
+
+ case LEAFPROB:
+ case FLEAFPROB:
+ SUCC1_OFFS = 0;
+ SUCC2_OFFS = 0;
+ BOFF_OFFS = 0;
+ I_FREQ_OFFS = FLAGS_OFFS+CHARSIZE;
+ I_FREQ_NUM = 0;
+ L_FREQ_NUM = 1;
+
+ ttype=tt;
+ break;
+
+ // LEAFPROB2/3/4 store 2..4 leaf probabilities but behave as LEAFPROB
+ case LEAFPROB2:
+ SUCC1_OFFS =0;
+ SUCC2_OFFS =0;
+ BOFF_OFFS =0;
+ I_FREQ_OFFS=FLAGS_OFFS+CHARSIZE;
+ I_FREQ_NUM=0;
+ L_FREQ_NUM=2;
+
+ ttype=LEAFPROB;
+ break;
+
+ case LEAFPROB3:
+ SUCC1_OFFS =0;
+ SUCC2_OFFS =0;
+ BOFF_OFFS =0;
+ I_FREQ_OFFS=FLAGS_OFFS+CHARSIZE;
+ I_FREQ_NUM=0;
+ L_FREQ_NUM=3;
+
+ ttype=LEAFPROB;
+ break;
+
+ case LEAFPROB4:
+ SUCC1_OFFS =0;
+ SUCC2_OFFS =0;
+ BOFF_OFFS =0;
+ I_FREQ_OFFS=FLAGS_OFFS+CHARSIZE;
+ I_FREQ_NUM=0;
+ L_FREQ_NUM=4;
+
+ ttype=LEAFPROB;
+ break;
+
+ default:
+ MY_ASSERT(tt==COUNT);
+ }
+
+ // common to every table type (makes the per-case assignments redundant)
+ L_FREQ_OFFS=CODESIZE;
+};
+
+// Builds an ngram table of maxl levels with the layout given by ttype/
+// codesize. If filename is non-NULL the table is populated from it,
+// dispatching on the file header: "nGrAm" (text), "NgRaM" (binary),
+// Google n-gram format, distant co-occurrences (dstco), history-mask
+// generation (hmask), or plain corpus text otherwise.
+ngramtable::ngramtable(char* filename,int maxl,char* /* unused parameter: is */,
+ dictionary* extdict /* external dictionary */,char* filterdictfile,
+ int googletable,int dstco,char* hmask, int inplen,TABLETYPE ttype,
+ int codesize): tabletype(ttype,codesize)
+{
+
+ cerr << "[codesize " << CODESIZE << "]\n";
+ char header[100];
+
+ info[0]='\0';
+
+ corrcounts=0;
+
+ // peek at the file header to detect the format and, for table files,
+ // override the requested depth with the stored one
+ if (filename) {
+ int n;
+ mfstream inp(filename,ios::in );
+
+ inp >> header;
+
+ if (strncmp(header,"nGrAm",5)==0 || strncmp(header,"NgRaM",5)==0) {
+ inp >> n;
+ inp >> card;
+ inp >> info;
+ if (strcmp(info,"LM_")==0) {
+ inp >> resolution;
+ inp >> decay;
+ char info2[100];
+ sprintf(info2,"%s %d %f",info,resolution,decay);
+ strcpy(info, info2);
+ } else { //default for old LM probs
+ resolution=10000000;
+ decay=0.9999;
+ }
+
+ maxl=n; //owerwrite maxl
+
+ cerr << n << " " << card << " " << info << "\n";
+ }
+
+ inp.close();
+ }
+
+ if (!maxl) {
+ exit_error(IRSTLM_ERROR_DATA,"ngramtable: ngram size must be specified");
+ }
+
+ //distant co-occurreces works for bigrams and trigrams
+ if (dstco && (maxl!=2) && (maxl!=3)) {
+ exit_error(IRSTLM_ERROR_DATA,"distant co-occurrences work with 2-gram and 3-gram");
+ }
+
+ maxlev=maxl;
+
+ //Root not must have maximum frequency size
+
+ treeflags=INODE | FREQ6;
+ tree=(node) new char[inodesize(6)];
+ memset(tree,0,inodesize(6));
+
+
+ //1-gram table initial flags
+ if (maxlev>1)
+ mtflags(tree,INODE | FREQ4);
+ else if (maxlev==1)
+ mtflags(tree,LNODE | FREQ4);
+ else {
+ exit_error(IRSTLM_ERROR_DATA,"ngramtable: wrong level setting");
+ }
+
+ word(tree,0); // dummy variable
+
+ if (I_FREQ_NUM)
+ freq(tree,treeflags,0); // frequency of all n-grams
+
+ msucc(tree,0); // number of different n-grams
+ mtable(tree,NULL); // table of n-gram
+
+ mem=new storage(256,10000);
+
+ mentr=new long long[maxlev+1];
+ memory= new long long[maxlev+1];
+ occupancy= new long long[maxlev+1];
+
+//Book keeping of occupied memory
+ mentr[0]=1;
+ memory[0]=inodesize(6); // root is an inode with highest frequency
+ occupancy[0]=inodesize(6); // root is an inode with highest frequency
+
+ for (int i=1; i<=maxlev; i++)
+ mentr[i]=memory[i]=occupancy[i]=0;
+
+ dict=new dictionary(NULL,1000000);
+
+ if (!filename) return ;
+
+ // optional dictionary used to filter out ngrams while loading
+ filterdict=NULL;
+ if (filterdictfile) {
+ filterdict=new dictionary(filterdictfile,1000000);
+ /*
+ filterdict->incflag(1);
+ filterdict->encode(BOS_);
+ filterdict->encode(EOS_);
+ filterdict->incflag(0);
+ */
+ }
+
+ // switch to specific loading methods
+
+ if ((strncmp(header,"ngram",5)==0) ||
+ (strncmp(header,"NGRAM",5)==0)) {
+ exit_error(IRSTLM_ERROR_DATA,"this ngram file format is no more supported");
+ }
+
+ if (strncmp(header,"nGrAm",5)==0)
+ loadtxt(filename);
+ else if (strncmp(header,"NgRaM",5)==0)
+ loadbin(filename);
+ else if (dstco>0)
+ generate_dstco(filename,dstco);
+ else if (hmask != NULL)
+ generate_hmask(filename,hmask,inplen);
+ else if (googletable)
+ loadtxt(filename,googletable);
+ else
+ generate(filename,extdict);
+
+
+
+ if (tbtype()==LEAFPROB) {
+ du_code=dict->encode(DUMMY_);
+ bo_code=dict->encode(BACKOFF_);
+ }
+}
+
+// Writes the table in text format to filename. With googleformat the
+// header and dictionary are omitted and levels startfrom..depth are
+// emitted; otherwise only level depth (default maxlev) is written after
+// a "nGrAm" header. With hashvalue each line also carries a CRC16 of
+// the ngram string.
+void ngramtable::savetxt(char *filename,int depth,bool googleformat,bool hashvalue,int startfrom)
+{
+ // NOTE(review): fixed 10000-byte line buffer — confirm no ngram can
+ // exceed it when hashvalue output is enabled.
+ char ngstring[10000];
+
+ if (depth>maxlev) {
+ exit_error(IRSTLM_ERROR_DATA,"ngramtable::savetxt: wrong n-gram size");
+ }
+
+ if (startfrom>0 && !googleformat) {
+ exit_error(IRSTLM_ERROR_DATA,
+ "ngramtable::savetxt: multilevel output only allowed in googleformat");
+ }
+
+ depth=(depth>0?depth:maxlev);
+
+ card=mentr[depth];
+
+ ngram ng(dict);
+
+ if (googleformat)
+ cerr << "savetxt in Google format: nGrAm " << depth << " " << card << " " << info << "\n";
+ else
+ cerr << "savetxt: nGrAm " << depth << " " << card << " " << info << "\n";
+
+ mfstream out(filename,ios::out );
+
+ if (!googleformat){
+ out << "nGrAm " << depth << " " << card << " " << info << "\n";
+ dict->save(out);
+ }
+
+ if (startfrom<=0 || startfrom > depth) startfrom=depth;
+
+ for (int d=startfrom;d<=depth;d++){
+ scan(ng,INIT,d);
+
+ while(scan(ng,CONT,d)){
+
+ if (hashvalue){
+ // rebuild the surface string oldest-first to hash it
+ strcpy(ngstring,ng.dict->decode(*ng.wordp(ng.size)));
+ for (int i=ng.size-1; i>0; i--){
+ strcat(ngstring," ");
+ strcat(ngstring,ng.dict->decode(*ng.wordp(i)));
+ }
+ out << ngstring << "\t" << ng.freq << "\t" << crc16_ccitt(ngstring,strlen(ngstring)) << "\n";
+ }
+ else
+
+ out << ng << "\n";
+
+
+ }
+ }
+ cerr << "\n";
+
+ out.close();
+}
+
+
+// Loads a text-format table. In the standard format the header line and
+// dictionary are read first; in Google format the dictionary is grown
+// on the fly and word frequencies are accumulated. If a filtering
+// dictionary is set, ngrams whose first word is not in it are skipped.
+void ngramtable::loadtxt(char *filename,int googletable)
+{
+
+ ngram ng(dict);;
+
+ cerr << "loadtxt:" << (googletable?"google format":"std table");
+
+ mfstream inp(filename,ios::in);
+
+ int i,c=0;
+
+ if (googletable) {
+ dict->incflag(1); // allow the dictionary to grow while reading
+ } else {
+ char header[100];
+ inp.getline(header,100);
+ cerr << header ;
+ dict->load(inp);
+ }
+
+ while (!inp.eof()) {
+
+ for (i=0; i<maxlev; i++) inp >> ng;
+
+ inp >> ng.freq;
+
+ if (ng.size==0) continue;
+
+ //update dictionary frequency when loading from
+ if (googletable) dict->incfreq(*ng.wordp(1),ng.freq);
+
+ // if filtering dictionary exists
+ // and if the first word of the ngram does not belong to it
+ // do not insert the ngram
+
+ if (filterdict) {
+ int code=filterdict->encode(dict->decode(*ng.wordp(maxlev)));
+ if (code!=filterdict->oovcode()) put(ng);
+ } else put(ng);
+
+ ng.size=0;
+
+ if (!(++c % 1000000)) cerr << "."; // progress dot every 1M ngrams
+
+ }
+
+ if (googletable) {
+ dict->incflag(0);
+ }
+
+ cerr << "\n";
+
+ inp.close();
+}
+
+
+
+// Recursively serializes the subtree rooted at nd (whose node type is
+// ndt) up to level mlev: word code, frequency (at the node's declared
+// byte width), then — for internal nodes — flags, successor count and
+// the successor entries. Nodes at level mlev-1 are written with their
+// INODE flag rewritten to LNODE so the saved tree ends at mlev.
+void ngramtable::savebin(mfstream& out,node nd,NODETYPE ndt,int lev,int mlev)
+{
+
+ out.write(nd+WORD_OFFS,CODESIZE);
+
+ //write frequency
+
+ int offs=(ndt & LNODE)?L_FREQ_OFFS:I_FREQ_OFFS;
+
+ int frnum=1;
+ if (tbtype()==LEAFPROB && (ndt & LNODE))
+ frnum=L_FREQ_NUM;
+
+ if ((ndt & LNODE) || I_FREQ_NUM) { //check if to write freq
+ if (ndt & FREQ1)
+ out.write(nd+offs,1 * frnum);
+ else if (ndt & FREQ2)
+ out.write(nd+offs,2 * frnum);
+ else if (ndt & FREQ3)
+ out.write(nd+offs,3 * frnum);
+ else
+ out.write(nd+offs,INTSIZE * frnum);
+ }
+
+ if ((lev <mlev) && (ndt & INODE)) {
+
+ unsigned char fl=mtflags(nd);
+ if (lev==(mlev-1))
+ //transforms flags into a leaf node
+ fl=(fl & ~INODE) | LNODE;
+
+ out.write((const char*) &fl,CHARSIZE);
+ fl=mtflags(nd); // restore the real flags for the recursion
+
+ out.write(nd+MSUCC_OFFS,CODESIZE);
+
+ int msz=mtablesz(nd);
+ int m=msucc(nd);
+
+ for (int i=0; i<m; i++)
+ savebin(out,mtable(nd) + i * msz,fl,lev+1,mlev);
+ }
+}
+
+
+// Serializes the full table (depth maxlev) to an already-open stream:
+// depth, root flags, then the recursive node dump. No header or
+// dictionary is written (see the filename overload for that).
+void ngramtable::savebin(mfstream& out)
+{
+
+ int depth=maxlev;
+
+ card=mentr[depth];
+
+ cerr << "ngramtable::savebin ";
+
+ out.writex((char *)&depth,INTSIZE);
+
+ out.write((char *)&treeflags,CHARSIZE);
+
+ savebin(out,tree,treeflags,0,depth);
+
+ cerr << "\n";
+}
+
+
+// Writes the table in binary format to filename, truncated at 'depth'
+// levels (default maxlev): "NgRaM"/"NgRaM_" header (the underscore marks
+// a dictionary containing OOV words), dictionary, depth, root flags and
+// the recursive node dump.
+void ngramtable::savebin(char *filename,int depth)
+{
+
+ if (depth > maxlev) {
+ exit_error(IRSTLM_ERROR_DATA,"ngramtable::savebin: wrong n-gram size");
+ }
+
+ depth=(depth>0?depth:maxlev);
+
+ card=mentr[depth];
+
+ cerr << "savebin NgRaM " << depth << " " << card;
+
+ mfstream out(filename,ios::out );
+
+ if (dict->oovcode()!=-1) //there are OOV words
+ out << "NgRaM_ " << depth << " " << card << " " << info << "\n";
+ else
+ out << "NgRaM " << depth << " " << card << " " << info << "\n";
+
+ dict->save(out);
+
+ out.writex((char *)&depth,INTSIZE);
+
+ out.write((char *)&treeflags,CHARSIZE);
+
+ savebin(out,tree,treeflags,0,depth);
+
+ out.close();
+
+ cerr << "\n";
+}
+
+
+// Recursively deserializes the subtree rooted at nd, mirroring
+// savebin(): word code, frequency at the declared byte width, then —
+// for internal nodes — flags, successor count and successors (growing
+// the successor table first). Updates the per-level entry and occupancy
+// counters; prints a progress dot every 1M leaves.
+void ngramtable::loadbin(mfstream& inp,node nd,NODETYPE ndt,int lev)
+{
+ static int c=0; // NOTE(review): static counter makes this non-reentrant
+
+ // read code
+ inp.read(nd+WORD_OFFS,CODESIZE);
+
+ // read frequency
+ int offs=(ndt & LNODE)?L_FREQ_OFFS:I_FREQ_OFFS;
+
+ int frnum=1;
+ if (tbtype()==LEAFPROB && (ndt & LNODE))
+ frnum=L_FREQ_NUM;
+
+ if ((ndt & LNODE) || I_FREQ_NUM) { //check if to read freq
+ if (ndt & FREQ1)
+ inp.read(nd+offs,1 * frnum);
+ else if (ndt & FREQ2)
+ inp.read(nd+offs,2 * frnum);
+ else if (ndt & FREQ3)
+ inp.read(nd+offs,3 * frnum);
+ else
+ inp.read(nd+offs,4 * frnum);
+ }
+
+ if (ndt & INODE) {
+
+ //read flags
+ inp.read(nd+FLAGS_OFFS,CHARSIZE);
+ unsigned char fl=mtflags(nd);
+
+ //read #of multiple entries
+ inp.read(nd+MSUCC_OFFS,CODESIZE);
+ int m=msucc(nd);
+
+ if (m>0) {
+ //read multiple entries
+ int msz=mtablesz(nd);
+ table mtb=mtable(nd);
+ //table entries increase
+ grow(&mtb,INODE,lev+1,m,msz);
+
+ for (int i=0; i<m; i++)
+ loadbin(inp,mtb + i * msz,fl,lev+1);
+
+ mtable(nd,mtb);
+ }
+
+ mentr[lev+1]+=m;
+ occupancy[lev+1]+=(m * mtablesz(nd));
+
+ } else if (!(++c % 1000000)) cerr << ".";
+
+}
+
+
+
// Counterpart of savebin(mfstream&): read order and root flags from an
// already-open stream (header and dictionary are the caller's job),
// then rebuild the tree recursively.
void ngramtable::loadbin(mfstream& inp)
{

  cerr << "loadbin ";

  inp.readx((char *)&maxlev,INTSIZE);
  inp.read((char *)&treeflags,CHARSIZE);

  loadbin(inp,tree,treeflags,0);

  cerr << "\n";
}
+
+
// Load a self-contained binary table written by savebin(char*,int):
// skip the one-line text header (echoed to stderr), load the
// dictionary, then the serialized tree.
void ngramtable::loadbin(const char *filename)
{

  cerr << "loadbin ";
  mfstream inp(filename,ios::in );

  //skip header
  char header[100];
  inp.getline(header,100);

  cerr << header ;

  dict->load(inp);

  inp.readx((char *)&maxlev,INTSIZE);
  inp.read((char *)&treeflags,CHARSIZE);

  loadbin(inp,tree,treeflags,0);

  inp.close();

  cerr << "\n";
}
+
+
// Build the table by reading plain-text n-grams from 'filename'.
// Words are re-encoded into this table's dictionary, which may grow
// during the load; if 'extdict' is given, input is first decoded with
// that prescribed dictionary.
void ngramtable::generate(char *filename, dictionary* extdict)
{
  mfstream inp(filename,ios::in);
  int i,c=0;   // c: input n-gram counter for the progress dots

  if (!inp) {
    std::stringstream ss_msg;
    ss_msg << "cannot open " << filename;
    exit_error(IRSTLM_ERROR_IO, ss_msg.str());
  }

  cerr << "load:";

  ngram ng(extdict==NULL?dict:extdict); //use possible prescribed dictionary
  if (extdict) dict->genoovcode();

  ngram ng2(dict);
  dict->incflag(1);   // allow the dictionary to grow while loading

  // pad with maxlev-1 sentence-begin symbols so lower-order counts are
  // consistent from the very first real word
  cerr << "prepare initial n-grams to make table consistent\n";
  for (i=1; i<maxlev; i++) {
    ng.pushw(dict->BoS());
    ng.freq=1;
  };

  while (inp >> ng) {

    if (ng.size>maxlev) ng.size=maxlev; //speeds up

    ng2.trans(ng); //reencode with new dictionary

    check_dictsize_bound();

    if (ng2.size) dict->incfreq(*ng2.wordp(1),1);

    // if filtering dictionary exists
    // and if the first word of the ngram does not belong to it
    // do not insert the ngram
    if (filterdict) {
      int code=filterdict->encode(dict->decode(*ng2.wordp(maxlev)));
      if (code!=filterdict->oovcode()) put(ng2);
    } else put(ng2);

    if (!(++c % 1000000)) cerr << ".";

  }

  // flush the tail: push maxlev sentence-begin symbols so the last
  // real words appear at every order
  cerr << "adding some more n-grams to make table consistent\n";
  for (i=1; i<=maxlev; i++) {
    ng2.pushw(dict->BoS());
    ng2.freq=1;

    // if filtering dictionary exists
    // and if the first word of the ngram does not belong to it
    // do not insert the ngram
    if (filterdict) {
      int code=filterdict->encode(dict->decode(*ng2.wordp(maxlev)));
      if (code!=filterdict->oovcode()) put(ng2);
    } else put(ng2);
  };

  dict->incflag(0);   // freeze the dictionary again
  inp.close();
  strcpy(info,"ngram");

  cerr << "\n";
}
+
+void ngramtable::generate_hmask(char *filename,char* hmask,int inplen)
+{
+ mfstream inp(filename,ios::in);
+
+ if (!inp) {
+ std::stringstream ss_msg;
+ ss_msg << "cannot open " << filename;
+ exit_error(IRSTLM_ERROR_IO, ss_msg.str());
+ }
+
+ int selmask[MAX_NGRAM];
+ memset(selmask, 0, sizeof(int)*MAX_NGRAM);
+
+ //parse hmask
+ selmask[0]=1;
+ int i=1;
+ for (size_t c=0; c<strlen(hmask); c++) {
+ cerr << hmask[c] << "\n";
+ if (hmask[c] == '1'){
+ selmask[i]=c+2;
+ i++;
+ }
+ }
+ if (i!= maxlev) {
+ std::stringstream ss_msg;
+ ss_msg << "wrong mask: 1 bits=" << i << " maxlev=" << maxlev;
+ exit_error(IRSTLM_ERROR_DATA, ss_msg.str());
+ }
+
+ cerr << "load:";
+
+ ngram ng(dict);
+ ngram ng2(dict);
+ dict->incflag(1);
+ long c=0;
+ while (inp >> ng) {
+
+ if (inplen && ng.size<inplen) continue;
+
+ ng2.trans(ng); //reencode with new dictionary
+ ng.size=0; //reset ng
+
+ if (ng2.size >= selmask[maxlev-1]) {
+ for (int j=0; j<maxlev; j++)
+ *ng2.wordp(j+1)=*ng2.wordp(selmask[i]);
+
+ //cout << ng2 << "size:" << ng2.size << "\n";
+ check_dictsize_bound();
+
+ put(ng2);
+ }
+
+ if (ng2.size) dict->incfreq(*ng2.wordp(1),1);
+
+ if (!(++c % 1000000)) cerr << ".";
+ };
+
+ dict->incflag(0);
+ inp.close();
+ sprintf(info,"hm%s\n",hmask);
+
+ cerr << "\n";
+}
+
// qsort comparator producing DESCENDING order of ints.
// Uses explicit comparisons instead of the original subtraction
// (*(int*)b - *(int*)a), which invokes signed-overflow UB for operands
// of opposite sign near INT_MIN/INT_MAX. Result sign is unchanged for
// all non-overflowing inputs.
int cmpint(const void *a,const void *b)
{
  int va = *(const int *)a;
  int vb = *(const int *)b;
  return (vb > va) - (vb < va);
}
+
// Build a table of distant co-occurrences: for every input sentence
// (truncated to a window of 'dstco' words) insert word pairs
// (maxlev==2) or word triples (maxlev==3) with canonically ordered
// codes, so each unordered co-occurrence maps to a single entry.
void ngramtable::generate_dstco(char *filename,int dstco)
{
  mfstream inp(filename,ios::in);
  int c=0;   // progress counter

  if (!inp) {
    std::stringstream ss_msg;
    ss_msg << "cannot open " << filename;
    exit_error(IRSTLM_ERROR_IO, ss_msg.str());
  }

  cerr << "load distant co-occurrences:";
  if (dstco>MAX_NGRAM) {
    inp.close();
    std::stringstream ss_msg;
    ss_msg << "window size (" << dstco << ") exceeds MAXNGRAM";
    exit_error(IRSTLM_ERROR_DATA, ss_msg.str());
  }

  ngram ng(dict);
  ngram ng2(dict);
  ngram dng(dict);
  dict->incflag(1);   // dictionary may grow while reading

  while (inp >> ng) {
    if (ng.size) {

      ng2.trans(ng); //reencode with new dictionary

      if (ng2.size>dstco) ng2.size=dstco; //maximum distance

      check_dictsize_bound();

      dict->incfreq(*ng2.wordp(1),1);

      if (maxlev == 1 )
        cerr << "maxlev is wrong! (Possible values are 2 or 3)\n";

      else if (maxlev == 2 ) { //maxlev ==2
        dng.size=2;
        dng.freq=1;

        // pair the last word with every other word in the window,
        // storing the smaller code in position 1
        for (int i=2; i<=ng2.size; i++) {

          if (*ng2.wordp(1)<*ng2.wordp(i)) {
            *dng.wordp(2)=*ng2.wordp(i);
            *dng.wordp(1)=*ng2.wordp(1);
          } else {
            *dng.wordp(1)=*ng2.wordp(i);
            *dng.wordp(2)=*ng2.wordp(1);
          }
          put(dng);
        }
        if (!(++c % 1000000)) cerr << ".";
      } else { //maxlev ==3
        dng.size=3;
        dng.freq=1;

        // build every triple (last word, w_i, w_j) in the window and
        // canonicalize it by sorting the codes (descending, see cmpint)
        int ar[3];

        ar[0]=*ng2.wordp(1);
        for (int i=2; i<ng2.size; i++) {
          ar[1]=*ng2.wordp(i);
          for (int j=i+1; j<=ng2.size; j++) {
            ar[2]=*ng2.wordp(j);

            //sort ar
            qsort(ar,3,sizeof(int),cmpint);

            *dng.wordp(1)=ar[0];
            *dng.wordp(2)=ar[1];
            *dng.wordp(3)=ar[2];

            put(dng);
          }
        }
      }
    }
  }
  dict->incflag(0);
  inp.close();
  sprintf(info,"co-occ%d\n",dstco);
  cerr << "\n";
}
+
+
+
// Merge all n-grams of 'ngt' (same order required) into this table,
// re-encoding every word into this table's dictionary, carrying over
// the word frequencies, and finally refreshing the OOV code.
void ngramtable::augment(ngramtable* ngt)
{

  if (ngt->maxlev != maxlev) {
    exit_error(IRSTLM_ERROR_DATA,"ngramtable::augment augmentation is not possible due to table incompatibility");
  }

  // report OOV/size statistics of both dictionaries before merging
  if (ngt->dict->oovcode()!=-1)
    cerr <<"oov: " << ngt->dict->freq(ngt->dict->oovcode()) << "\n";
  cerr <<"size: " << ngt->dict->size() << "\n";

  if (dict->oovcode()!=-1)
    cerr <<"oov: " << dict->freq(dict->oovcode()) << "\n";
  cerr <<"size: " << dict->size() << "\n";


  dict->incflag(1);   // allow new words while merging
  cerr << "augmenting ngram table\n";
  ngram ng1(ngt->dict);
  ngram ng2(dict);
  ngt->scan(ng1,INIT);
  int c=0;
  while (ngt->scan(ng1,CONT)) {
    ng2.trans(ng1);   // re-encode into this table's dictionary
    put(ng2);
    if ((++c % 1000000) ==0) cerr <<".";
  }
  cerr << "\n";

  // carry over the unigram frequencies of the merged dictionary
  for (int i=0; i<ngt->dict->size(); i++)
    dict->incfreq(dict->encode(ngt->dict->decode(i)),
                  ngt->dict->freq(i));

  dict->incflag(0);

  // the merge may have introduced the OOV symbol: refresh its code
  int oov=dict->getcode(dict->OOV());

  if (oov>=0) {
    dict->oovcode(oov);
  }

  cerr << "oov: " << dict->freq(dict->oovcode()) << "\n";
  cerr << "size: " << dict->size() << "\n";
}
+
+void ngramtable::show()
+{
+
+ ngram ng(dict);
+
+ scan(ng,INIT);
+ cout << "Stampo contenuto della tabella\n";
+ while (scan(ng)) {
+ cout << ng << "\n";
+ }
+}
+
+
+
+int ngramtable::mybsearch(char *ar, int n, int size, unsigned char *key, int *idx)
+{
+ if (n==0) return 0;
+
+ register int low = 0, high = n;
+ *idx=0;
+ register unsigned char *p=NULL;
+ int result;
+
+#ifdef INTERP_SEARCH
+ char* lp;
+ char* hp;
+#endif
+
+ /* return idx with the first
+ position equal or greater than key */
+
+ /* Warning("start bsearch \n"); */
+
+
+ while (low < high) {
+
+
+#ifdef INTERP_SEARCH
+ //use interpolation search only for intervals with at least 4096 entries
+
+ if ((high-low)>=10000) {
+
+ lp=(char *) (ar + (low * size));
+ if (codecmp((char *)key,lp)<0) {
+ *idx=low;
+ return 0;
+ }
+
+ hp=(char *) (ar + ((high-1) * size));
+ if (codecmp((char *)key,hp)>0) {
+ *idx=high;
+ return 0;
+ }
+
+ *idx= low + ((high-1)-low) * codediff((char *)key,lp)/codediff(hp,(char *)lp);
+ } else
+#endif
+ *idx = (low + high) / 2;
+
+ //after redefining the interval there is no guarantee
+ //that wlp <= wkey <= whigh
+
+ p = (unsigned char *) (ar + (*idx * size));
+ result=codecmp((char *)key,(char *)p);
+
+ if (result < 0) {
+ high = *idx;
+ } else if (result > 0) {
+ low = ++(*idx);
+ } else
+ return 1;
+ }
+
+ *idx=low;
+
+ return 0;
+
+}
+
// Locate, insert or remove one word code within a sorted successor table.
// 'tb'   : address of the table pointer (may be reallocated on ENTER)
// 'ndt'  : flags describing the record layout; 'lev': tree level
// 'n'    : current #entries; 'sz': record size; 'ngp': word code
// 'found': on FIND/ENTER-hit set to the matching record; left NULL when
//          ENTER inserts a fresh one; on DELETE set past the last entry.
// Returns a pointer to the relevant record, or NULL/0 on a miss.
void *ngramtable::search(table *tb,NODETYPE ndt,int lev,int n,int sz,int *ngp,
                         ACTION action,char **found)
{

  // serialize the code into CODESIZE bytes for the binary comparison
  char w[CODESIZE];
  putmem(w,ngp[0],0,CODESIZE);
  int wint=ngp[0];


  // index returned by mybsearch

  if (found) *found=NULL;

  int idx=0;

  switch(action) {

  case ENTER:

    if (!*tb ||
        !mybsearch(*tb,n,sz,(unsigned char *)w,&idx)) {
      // let possibly grow the table
      grow(tb,ndt,lev,n,sz); // we must add one element (n+1)

      //shift table by one to open the slot at the insertion point idx

      memmove(*tb + (idx+1) * sz,
              *tb + idx * sz,
              (n-idx) * sz);

      memset(*tb + idx * sz , 0 , sz);

      word(*tb + idx * sz, wint);

    } else if (found) *found=*tb + ( idx * sz );

    return *tb + ( idx * sz );

    break;


  case FIND:

    if (!*tb ||
        !mybsearch(*tb,n,sz,(unsigned char *)w,&idx))
      return 0;
    else if (found) *found=*tb + (idx * sz);

    return *tb + (idx * sz);

    break;

  case DELETE:

    if (*tb &&
        mybsearch(*tb,n,sz,(unsigned char *)w,&idx)) {
      //shift table down by one

      //NOTE(review): static scratch buffer makes this path non
      //thread-safe and silently assumes record size sz <= 100 bytes
      static char buffer[100];

      memcpy(buffer,*tb + idx * sz , sz);

      if (idx <(n-1))
        memmove(*tb + idx * sz,
                *tb + (idx + 1) * sz,
                (n-idx-1) * sz);

      //put the deleted item after the last item

      memcpy(*tb + (n-1) * sz , buffer , sz);

      if (found) *found=*tb + (n-1) * sz ;

      return *tb + (n-1) * sz ;

    } else

      return NULL;

    break;

  default:
    cerr << "this option is not implemented yet\n";
    break;
  }

  return NULL;

}
+
+int ngramtable::comptbsize(int n)
+{
+
+ if (n>16384)
+ return(n/16384)*16384+(n % 16384?16384:0);
+ else if (n>8192) return 16384;
+ else if (n>4096) return 8192;
+ else if (n>2048) return 4096;
+ else if (n>1024) return 2048;
+ else if (n>512) return 1024;
+ else if (n>256) return 512;
+ else if (n>128) return 256;
+ else if (n>64) return 128;
+ else if (n>32) return 64;
+ else if (n>16) return 32;
+ else if (n>8) return 16;
+ else if (n>4) return 8;
+ else if (n>2) return 4;
+ else if (n>1) return 2;
+ else return 1;
+
+}
+
+
// Resize or migrate one successor table; returns 'tb' with the (possibly
// new) pointer stored back through it. Two modes:
//  - oldndt==0: (re)allocation. With *tb==NULL and n>0 allocate capacity
//    for n entries at once; otherwise extend an existing table by one
//    allocation step, but only when its logical size n has exactly
//    filled the current capacity (powers of two up to 8192, then
//    16384-entry chunks) — any other n means there is still room.
//  - oldndt!=0: the frequency width changed (e.g. FREQ1->FREQ2): copy
//    all n records from the old layout 'oldndt' into a fresh table with
//    record size 'sz' laid out as 'ndt', then free the old storage.
// Per-level memory/occupancy accounting is updated in both modes.
char **ngramtable::grow(table *tb,NODETYPE ndt,int lev,
                        int n,int sz,NODETYPE oldndt)
{
  int inc;
  int num;

  //memory pools for inode/lnode tables

  if (oldndt==0) {

    if ((*tb==NULL) && n>0) {
      // n is the target number of entries
      //first allocation

      if (n>16384)
        inc=(n/16384)*16384+(n % 16384?16384:0);
      else if (n>8192) inc=16384;
      else if (n>4096) inc=8192;
      else if (n>2048) inc=4096;
      else if (n>1024) inc=2048;
      else if (n>512) inc=1024;
      else if (n>256) inc=512;
      else if (n>128) inc=256;
      else if (n>64) inc=128;
      else if (n>32) inc=64;
      else if (n>16) inc=32;
      else if (n>8) inc=16;
      else if (n>4) inc=8;
      else if (n>2) inc=4;
      else if (n>1) inc=2;
      else inc=1;

      n=0; //inc is the correct target size

    }

    else {
      // table will be extended on demand
      // I'm sure that one entry will be
      // added next

      // check multiples of 16384
      if ((n>=16384) && !(n % 16384)) inc=16384;
      else {
        switch (n) {
        case 0:
          inc=1;
          break;
        case 1:
        case 2:
        case 4:
        case 8:
        case 16:
        case 32:
        case 64:
        case 128:
        case 256:
        case 512:
        case 1024:
        case 2048:
        case 4096:
        case 8192:
          inc=n;   // capacity full: double it
          break;
        default:
          return tb;   // still room inside the current capacity
        }
      }
    }

    table ntb=(char *)mem->reallocate(*tb,n * sz,(n + inc) * sz);

    memory[lev]+= (inc * sz);

    *tb=ntb;
  }

  else {
    //change frequency type of table
    //no entries will be added now

    int oldsz=0;

    // guess the current memory size !!!!
    num=comptbsize(n);

    // recover the old record size from the old frequency-width flag
    if ((ndt & INODE) && I_FREQ_NUM) {
      if (oldndt & FREQ1)
        oldsz=inodesize(1);
      else if (oldndt & FREQ2)
        oldsz=inodesize(2);
      else if (oldndt & FREQ3)
        oldsz=inodesize(3);
      else if (oldndt & FREQ4)
        oldsz=inodesize(4);
      else {
        exit_error(IRSTLM_ERROR_DATA,"ngramtable::grow functionality not available");
      }
    } else if (ndt & LNODE) {
      if (oldndt & FREQ1)
        oldsz=lnodesize(1);
      else if (oldndt & FREQ2)
        oldsz=lnodesize(2);
      else if (oldndt & FREQ3)
        oldsz=lnodesize(3);
      else if (oldndt & FREQ4)
        oldsz=lnodesize(4);
      else {
        exit_error(IRSTLM_ERROR_DATA,"ngramtable::grow functionality not available");
      }
    }

    table ntb=(char *)mem->allocate(num * sz);
    memset((char *)ntb,0,num * sz);

    // field-by-field copy: offsets differ between the two layouts
    if (ndt & INODE)
      for (int i=0; i<n; i++) {
        word(ntb+i*sz,word(*tb+i*oldsz));
        msucc(ntb+i*sz,msucc(*tb+i*oldsz));
        mtflags(ntb+i*sz,mtflags(*tb+i*oldsz));
        mtable(ntb+i*sz,mtable(*tb+i*oldsz));
        for (int j=0; j<I_FREQ_NUM; j++)
          setfreq(ntb+i*sz,ndt,getfreq(*tb+i*oldsz,oldndt,j),j);
      }
    else
      for (int i=0; i<n; i++) {
        word(ntb+i*sz,word(*tb+i*oldsz));
        for (int j=0; j<L_FREQ_NUM; j++)
          setfreq(ntb+i*sz,ndt,getfreq(*tb+i*oldsz,oldndt,j),j);
      }

    mem->free(*tb,num * oldsz); //num is the correct size
    memory[lev]+=num * (sz - oldsz);
    occupancy[lev]+=n * (sz - oldsz);

    *tb=ntb;
  }

  return tb;

};
+
+
// Insert n-gram 'ng' starting at the root. Qualified call pins this
// class's implementation even when invoked from a derived class.
int ngramtable::put(ngram& ng)
{

  return ngramtable::put(ng,tree,treeflags,0);

}
+
// Insert n-gram 'ng' below node 'nd' (flags 'ndt') starting at level
// 'lev', accumulating ng.freq on every node along the path. Missing
// nodes are created, and a successor table whose counters would
// overflow their current width is promoted to the next width
// (FREQ1->FREQ2->FREQ3->FREQ4->FREQ6). Returns 1 on success, 0 when
// the n-gram is shorter than the table order.
int ngramtable::put(ngram& ng,node nd,NODETYPE ndt,int lev)
{
  char *found;
  node subnd;

  if (ng.size<maxlev) return 0;

  for (int l=lev; l<maxlev; l++) {

    // accumulate the count on the current node
    if (I_FREQ_NUM || (ndt & LNODE))
      freq(nd,ndt,freq(nd,ndt) + ng.freq);

    table mtb=mtable(nd);

    // locate (or create) the successor record for the next word

    subnd=(char *)
          search(&mtb,
                 mtflags(nd),
                 l+1,
                 msucc(nd),
                 mtablesz(nd),
                 ng.wordp(maxlev-l),
                 ENTER,&found);

    if (!found) { //a new element has been added

      msucc(nd,msucc(nd)+1);

      mentr[l+1]++;
      occupancy[l+1]+=mtablesz(nd);

      unsigned char freq_flag;
      if (I_FREQ_NUM)
        //tree with internal freqs must
        //be never expanded during usage
        //of the secondary frequencies
        freq_flag=(ng.freq>65535?FREQ4:FREQ1);
      else
        //all leafprob with L_FREQ_NUM >=1
        //do NOT have INTERNAL freqs
        //will have freq size specified
        //by the resolution parameter
        //to avoid expansion
        freq_flag=L_FREQ_SIZE;

      if ((l+1)<maxlev) { //update mtable flags
        if ((l+2)<maxlev)
          mtflags(subnd,INODE | freq_flag);
        else
          mtflags(subnd,LNODE | freq_flag);

      }
    }

    // ... go on with the subtree

    // check whether the incremented count still fits the current
    // frequency width of the successor table; widen flags step by step

    NODETYPE oldndt=mtflags(nd);

    if ((I_FREQ_NUM || (mtflags(nd) & LNODE)) &&
        (mtflags(nd) & FREQ1) &&
        ((freq(subnd,mtflags(nd))+ng.freq)>255))

      mtflags(nd,(mtflags(nd) & ~FREQ1) | FREQ2); //update flags


    if ((I_FREQ_NUM || (mtflags(nd) & LNODE)) &&
        (mtflags(nd) & FREQ2) &&
        ((freq(subnd,mtflags(nd))+ng.freq)>65535))

      mtflags(nd,(mtflags(nd) & ~FREQ2) | FREQ3); //update flags


    if ((I_FREQ_NUM || (mtflags(nd) & LNODE)) &&
        (mtflags(nd) & FREQ3) &&
        ((freq(subnd,mtflags(nd))+ng.freq)>16777215))

      mtflags(nd,(mtflags(nd) & ~FREQ3) | FREQ4); //update flags

    if ((I_FREQ_NUM || (mtflags(nd) & LNODE)) &&
        (mtflags(nd) & FREQ4) &&
        ((freq(subnd,mtflags(nd))+ng.freq)>4294967295LL))

      mtflags(nd,(mtflags(nd) & ~FREQ4) | FREQ6); //update flags

    if (mtflags(nd)!=oldndt) {
      // flags have changed, table has to be expanded
      //expand subtable
      cerr << "+"<<l+1;
      //table entries remain the same
      grow(&mtb,mtflags(nd),l+1,msucc(nd),mtablesz(nd),oldndt);
      cerr << "\b\b";
      //the records moved: re-locate the current successor
      subnd=(char *)
            search(&mtb,
                   mtflags(nd),
                   l+1,
                   msucc(nd),
                   mtablesz(nd),
                   ng.wordp(maxlev-l),
                   FIND,&found);
    }


    mtable(nd,mtb);
    ndt=mtflags(nd);
    nd=subnd;   // descend one level
  }

  // finally add the count on the deepest (leaf) node
  freq(nd, ndt, freq(nd,ndt) + ng.freq);

  return 1;
}
+
+
+
// Look up the n-gram made of words wordp(n)..wordp(n-lev+1), descending
// 'lev' levels from the root, and fill ng's bookkeeping fields (freq,
// link, lev, pinfo and, for partial lookups, info/succ of the reached
// node). Returns 1 on hit, 0 on miss.
int ngramtable::get(ngram& ng,int n,int lev)
{

  node nd,subnd;
  char *found;
  NODETYPE ndt;

  MY_ASSERT(lev <= n && lev <= maxlev && ng.size >= n);

  // LEAFPROB-style tables keep no internal counts, so partial lookups
  // are meaningless there
  if ((I_FREQ_NUM==0) && (lev < maxlev)) {
    exit_error(IRSTLM_ERROR_DATA,"ngramtable::get for this type of table ngram cannot be smaller than table size");
  }


  if (ng.wordp(n)) {

    nd=tree;
    ndt=treeflags;

    for (int l=0; l<lev; l++) {

      table mtb=mtable(nd);

      subnd=(char *)
            search(&mtb,
                   mtflags(nd),
                   l+1,
                   msucc(nd),
                   mtablesz(nd),
                   ng.wordp(n-l),
                   FIND,&found);

      ndt=mtflags(nd);   // flags of the table we just searched
      nd=subnd;

      if (nd==0) return 0;   // word not found at this level
    }

    ng.size=n;
    ng.freq=freq(nd,ndt);
    ng.link=nd;
    ng.lev=lev;
    ng.pinfo=ndt; //parent node info

    if (lev<maxlev) {
      ng.succ=msucc(nd);
      ng.info=mtflags(nd);
    } else {
      ng.succ=0;
      ng.info=LNODE;
    }
    return 1;
  }
  return 0;
}
+
+
// Depth-first enumeration of all n-grams of order 'maxl' below node
// 'nd'. INIT resets the per-level cursors stored in ng.midx; each CONT
// call yields the next n-gram into 'ng' and returns 1, or returns 0
// when the subtree is exhausted.
int ngramtable::scan(node nd,NODETYPE /* unused parameter: ndt */,int lev,ngram& ng,ACTION action,int maxl)
{

  MY_ASSERT(lev<=maxlev);

  // LEAFPROB tables have no internal counts: only full-order scans work
  if ((I_FREQ_NUM==0) && (maxl < maxlev)) {
    exit_error(IRSTLM_ERROR_MODEL,"ngramtable::scan ngram cannot be smaller than LEAFPROB table");
  }


  if (maxl==-1) maxl=maxlev;   // default: scan full-order n-grams

  ng.size=maxl;

  switch (action) {


  case INIT:
    //reset ngram local indexes

    for (int l=0; l<=maxlev; l++) ng.midx[l]=0;

    return 1;

  case CONT:

    if (lev>(maxl-1)) return 0;

    if (ng.midx[lev]<msucc(nd)) {
      //put current word into ng
      *ng.wordp(maxl-lev)=
        word(mtable(nd)+ng.midx[lev] * mtablesz(nd));

      //inspect subtree
      //check if there is something left in the tree

      if (lev<(maxl-1)) {
        if (scan(mtable(nd) + ng.midx[lev] * mtablesz(nd),
                 INODE,
                 lev+1,ng,CONT,maxl))
          return 1;
        else {
          // current child exhausted: advance this level's cursor and
          // reset all deeper cursors before retrying
          ng.midx[lev]++; //go to next
          for (int l=lev+1; l<=maxlev; l++) ng.midx[l]=0; //reset indexes

          return scan(nd,INODE,lev,ng,CONT,maxl); //restart scanning
        }
      } else {
        // deepest requested level: fill in the n-gram payload

        *ng.wordp(maxl-lev)=
          word(mtable(nd)+ng.midx[lev] * mtablesz(nd));

        ng.freq=freq(mtable(nd)+ ng.midx[lev] * mtablesz(nd),mtflags(nd));
        ng.pinfo=mtflags(nd);

        if (maxl<maxlev) {
          ng.info=mtflags(mtable(nd)+ ng.midx[lev] * mtablesz(nd));
          ng.link=mtable(nd)+ng.midx[lev] * mtablesz(nd); //link to the node
          ng.succ=msucc(mtable(nd)+ ng.midx[lev] * mtablesz(nd));
        } else {
          ng.info=LNODE;
          ng.link=NULL;
          ng.succ=0;
        }

        ng.midx[lev]++;   // advance for the next CONT call

        return 1;
      }
    } else
      return 0;

  default:
    cerr << "scan: not supported action\n";
    break;

  }
  return 0;
}
+
+
// Recursively release the successor table hanging off node 'nd'.
// Only the logical entry count is stored in the node, so the allocated
// capacity is recomputed with comptbsize().
void ngramtable::freetree(node nd)
{
  int m=msucc(nd);
  int msz=mtablesz(nd);
  int truem=comptbsize(m);   // allocated capacity, not logical size

  if (mtflags(nd) & INODE)   // children are internal nodes: recurse first
    for (int i=0; i<m; i++)
      freetree(mtable(nd) + i * msz);
  mem->free(mtable(nd),msz*truem);
}
+
+
// Release everything: the whole subtree, the root node, the memory
// pool and the per-level statistics arrays.
// NOTE(review): 'dict' is deleted unconditionally; confirm the class
// always owns it (an externally supplied dictionary would be freed too).
ngramtable::~ngramtable()
{
  freetree(tree);
  delete [] tree;
  delete mem;
  delete [] memory;
  delete [] occupancy;
  delete [] mentr;
  delete dict;
};
+
// Print per-level entry counts and memory usage. 'level' controls the
// verbosity: >1 also prints dictionary stats, >2 adds memory-pool stats.
void ngramtable::stat(int level)
{
  long long totmem=0;
  long long totwaste=0;
  float mega=1024 * 1024;   // bytes per MB, float so the divisions below are fractional

  cout.precision(2);

  cout << "ngramtable class statistics\n";

  cout << "levels " << maxlev << "\n";
  for (int l=0; l<=maxlev; l++) {
    cout << "lev " << l
         << " entries "<< mentr[l]
         << " allocated mem " << memory[l]/mega << "Mb "
         << " used mem " << occupancy[l]/mega << "Mb \n";
    totmem+=memory[l];
    totwaste+=(memory[l]-occupancy[l]);
  }

  cout << "total allocated mem " << totmem/mega << "Mb ";
  cout << "wasted mem " << totwaste/mega << "Mb\n\n\n";

  if (level >1 ) dict->stat();

  cout << "\n\n";

  if (level >2) mem->stat();

}
+
+
// Probability of n-gram 'ong' for LEAFPROB-style tables. Histories
// shorter than maxlev are padded with the dummy code; on a miss the
// n-gram is shortened by one and the result is multiplied by a backoff
// weight, retrieved via the reserved backoff code in position 1.
double ngramtable::prob(ngram ong)
{

  if (ong.size==0) return 0.0;
  if (ong.size>maxlev) ong.size=maxlev;   // never query beyond the table order

  MY_ASSERT(tbtype()==LEAFPROB && ong.size<=maxlev);

  ngram ng(dict);
  ng.trans(ong);

  double bo;

  // pad missing history positions with the dummy code
  ng.size=maxlev;
  for (int s=ong.size+1; s<=maxlev; s++)
    *ng.wordp(s)=du_code;

  if (get(ng)) {

    // stored frequencies encode quantized probabilities:
    // decay^(resolution-freq) for quantized tables, (freq+1)/10^7 otherwise
    if (ong.size>1 && resolution<10000000)
      return (double)pow(decay,(resolution-ng.freq));
    else
      return (double)(ng.freq+1)/10000000.0;

  } else { // backoff-probability

    bo_state(1); //set backoff state to 1

    *ng.wordp(1)=bo_code;   // fetch the stored backoff weight

    if (get(ng))

      bo=resolution<10000000
         ?(double)pow(decay,(resolution-ng.freq))
         :(double)(ng.freq+1)/10000000.0;

    else
      bo=1.0;

    ong.size--;   // recurse on the shortened n-gram

    return bo * prob(ong);
  }
}
+
+
+bool ngramtable::check_dictsize_bound()
+{
+ if (dict->size() >= code_range[CODESIZE]) {
+ std::stringstream ss_msg;
+ ss_msg << "dictionary size overflows code range " << code_range[CODESIZE];
+ exit_error(IRSTLM_ERROR_MODEL, ss_msg.str());
+ }
+ return true;
+}
+
// Overwrite the stored frequency of an already-present n-gram with
// ng.freq; aborts if the n-gram is not in the table. Returns 1.
int ngramtable::update(ngram ng) {

  if (!get(ng,ng.size,ng.size)) {
    std::stringstream ss_msg;
    ss_msg << "cannot find " << ng;
    exit_error(IRSTLM_ERROR_MODEL, ss_msg.str());
  }

  // get() left link/pinfo pointing at the node: write the count in place
  freq(ng.link,ng.pinfo,ng.freq);

  return 1;
}
+
// Free every successor table and re-initialize the root so the table
// is empty again, with all per-level statistics zeroed.
void ngramtable::resetngramtable() {
  //clean up all memory and restart from an empty table

  freetree(); //clean memory pool
  memset(tree,0,inodesize(6)); //reset tree
  //1-gram table initial flags

  if (maxlev>1) mtflags(tree,INODE | FREQ4);
  else if (maxlev==1) mtflags(tree,LNODE | FREQ4);

  word(tree,0); //dummy word
  msucc(tree,0); // number of different n-grams
  mtable(tree,NULL); // table of n-gram

  for (int i=1; i<=maxlev; i++)
    mentr[i]=memory[i]=occupancy[i]=0;

}
+
+int ngramtable::putmem(char* ptr,int value,int offs,int size) {
+ MY_ASSERT(ptr!=NULL);
+ for (int i=0; i<size; i++)
+ ptr[offs+i]=(value >> (8 * i)) & 0xff;
+ return value;
+}
+
+int ngramtable::getmem(char* ptr,int* value,int offs,int size) {
+ MY_ASSERT(ptr!=NULL);
+ *value=ptr[offs] & 0xff;
+ for (int i=1; i<size; i++)
+ *value= *value | ( ( ptr[offs+i] & 0xff ) << (8 *i));
+ return *value;
+}
+
+long ngramtable::putmem(char* ptr,long long value,int offs,int size) {
+ MY_ASSERT(ptr!=NULL);
+ for (int i=0; i<size; i++)
+ ptr[offs+i]=(value >> (8 * i)) & 0xffLL;
+ return value;
+}
+
+long ngramtable::getmem(char* ptr,long long* value,int offs,int size) {
+ MY_ASSERT(ptr!=NULL);
+ *value=ptr[offs] & 0xff;
+ for (int i=1; i<size; i++)
+ *value= *value | ( ( ptr[offs+i] & 0xffLL ) << (8 *i));
+ return *value;
+}
+
// Copy n word codes from packed table storage into an int array.
void ngramtable::tb2ngcpy(int* wordp,char* tablep,int n) {
  for (int i=0; i<n; i++)
    getmem(tablep,&wordp[i],i*CODESIZE,CODESIZE);
}
+
// Copy n word codes from an int array into packed table storage.
void ngramtable::ng2tbcpy(char* tablep,int* wordp,int n) {
  for (int i=0; i<n; i++)
    putmem(tablep,wordp[i],i*CODESIZE,CODESIZE);
}
+
// Compare n word codes against packed table storage:
// returns 0 if all equal, 1 at the first mismatch.
int ngramtable::ngtbcmp(int* wordp,char* tablep,int n) {
  int word;
  for (int i=0; i<n; i++) {
    getmem(tablep,&word,i*CODESIZE,CODESIZE);
    if (wordp[i]!=word) return 1;
  }
  return 0;
}
+
+int ngramtable::codecmp(char * a,char *b) {
+ register int i,result;
+ for (i=(CODESIZE-1); i>=0; i--) {
+ result=(unsigned char)a[i]-(unsigned char)b[i];
+ if(result) return result;
+ }
+ return 0;
+};
+
// Store 'value' as the (primary) frequency of node 'nd', using the byte
// width selected by the node's FREQ* flag (1/2/3/4 bytes, or 6 when no
// width flag is set). The offset depends on leaf vs. internal layout.
long long ngramtable::freq(node nd,NODETYPE ndt,long long value) {
  int offs=(ndt & LNODE)?L_FREQ_OFFS:I_FREQ_OFFS;

  if (ndt & FREQ1)
    putmem(nd,value,offs,1);
  else if (ndt & FREQ2)
    putmem(nd,value,offs,2);
  else if (ndt & FREQ3)
    putmem(nd,value,offs,3);
  else if (ndt & FREQ4)
    putmem(nd,value,offs,4);
  else
    putmem(nd,value,offs,6);
  return value;
}
+
// Read the (primary) frequency of node 'nd', using the byte width
// selected by the node's FREQ* flag (1/2/3/4 bytes, 6 by default).
long long ngramtable::freq(node nd,NODETYPE ndt) {
  int offs=(ndt & LNODE)?L_FREQ_OFFS:I_FREQ_OFFS;
  long long value;

  if (ndt & FREQ1)
    getmem(nd,&value,offs,1);
  else if (ndt & FREQ2)
    getmem(nd,&value,offs,2);
  else if (ndt & FREQ3)
    getmem(nd,&value,offs,3);
  else if (ndt & FREQ4)
    getmem(nd,&value,offs,4);
  else
    getmem(nd,&value,offs,6);

  return value;
}
+
+
// Store 'value' into the 'index'-th stacked frequency field of node
// 'nd'; fields are laid out back to back, each as wide as the FREQ* flag
// dictates (6 bytes when no width flag is set).
long long ngramtable::setfreq(node nd,NODETYPE ndt,long long value,int index) {
  int offs=(ndt & LNODE)?L_FREQ_OFFS:I_FREQ_OFFS;

  if (ndt & FREQ1)
    putmem(nd,value,offs+index * 1,1);
  else if (ndt & FREQ2)
    putmem(nd,value,offs+index * 2,2);
  else if (ndt & FREQ3)
    putmem(nd,value,offs+index * 3,3);
  else if (ndt & FREQ4)
    putmem(nd,value,offs+index * 4,4);
  else
    putmem(nd,value,offs+index * 6,6);

  return value;
}
+
// Read the 'index'-th stacked frequency field of node 'nd'; mirror of
// setfreq() with the same layout rules.
long long ngramtable::getfreq(node nd,NODETYPE ndt,int index) {
  int offs=(ndt & LNODE)?L_FREQ_OFFS:I_FREQ_OFFS;

  long long value;

  if (ndt & FREQ1)
    getmem(nd,&value,offs+ index * 1,1);
  else if (ndt & FREQ2)
    getmem(nd,&value,offs+ index * 2,2);
  else if (ndt & FREQ3)
    getmem(nd,&value,offs+ index * 3,3);
  else if (ndt & FREQ4)
    getmem(nd,&value,offs+ index * 4,4);
  else
    getmem(nd,&value,offs+ index * 6,6);

  return value;
}
+
+table ngramtable::mtable(node nd) {
+ char v[PTRSIZE];;
+ for (int i=0; i<PTRSIZE; i++)
+ v[i]=nd[MTAB_OFFS+i];
+
+ return *(table *)v;
+}
+
+table ngramtable::mtable(node nd,table value) {
+ char *v=(char *)&value;
+ for (int i=0; i<PTRSIZE; i++)
+ nd[MTAB_OFFS+i]=v[i];
+ return value;
+}
+
// Record size (in bytes) of nd's successor table, derived from the
// node kind (LNODE/INODE) and the frequency width flag stored in nd.
// Aborts on corrupt flags lacking both kind bits.
int ngramtable::mtablesz(node nd) {
  if (mtflags(nd) & LNODE) {
    if (mtflags(nd) & FREQ1)
      return lnodesize(1);
    else if (mtflags(nd) & FREQ2)
      return lnodesize(2);
    else if (mtflags(nd) & FREQ3)
      return lnodesize(3);
    else if (mtflags(nd) & FREQ4)
      return lnodesize(4);
    else
      return lnodesize(6);
  } else if (mtflags(nd) & INODE) {
    if (mtflags(nd) & FREQ1)
      return inodesize(1);
    else if (mtflags(nd) & FREQ2)
      return inodesize(2);
    else if (mtflags(nd) & FREQ3)
      return inodesize(3);
    else if (mtflags(nd) & FREQ4)
      return inodesize(4);
    else
      return inodesize(6);
  } else {
    exit_error(IRSTLM_ERROR_DATA,"ngramtable::mtablesz node has wrong flags");
  }

  return lnodesize(1); //this instruction is never reached
}
+
+
+/*
+ main(int argc, char** argv){
+ dictionary d(argv[1]);
+
+ ngram ng(&d);
+
+ cerr << "caricato dizionario da " << argv[1] << "\n";
+
+ ngramtable t(&d,argv[2],1);
+
+ t.stat(1);
+ t.savetxt(argv[3]);
+
+ }
+*/
+
+
+
+
+
diff --git a/src/ngramtable.h b/src/ngramtable.h
new file mode 100644
index 0000000..1286a12
--- /dev/null
+++ b/src/ngramtable.h
@@ -0,0 +1,379 @@
+// $Id: ngramtable.h 34 2010-06-03 09:19:34Z nicolabertoldi $
+
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit, compile LM
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+
+#ifndef MF_NGRAMTABLE_H
+#define MF_NGRAMTABLE_H
+
+#include "n_gram.h"
+
+//Backoff symbol
+#ifndef BACKOFF_
+#define BACKOFF_ "_backoff_"
+#endif
+
+//Dummy symbol
+#ifndef DUMMY_
+#define DUMMY_ "_dummy_"
+#endif
+
+// internal data structure
+
+#ifdef MYCODESIZE
+#define DEFCODESIZE MYCODESIZE
+#else
+#define DEFCODESIZE (int)2
+#endif
+
+#define SHORTSIZE (int)2
+#define PTRSIZE (int)sizeof(char *)
+#define INTSIZE (int)4
+#define CHARSIZE (int)1
+
+
+//Node flags
+#define FREQ1 (unsigned char) 1
+#define FREQ2 (unsigned char) 2
+#define FREQ4 (unsigned char) 4
+#define INODE (unsigned char) 8
+#define LNODE (unsigned char) 16
+#define SNODE (unsigned char) 32
+#define FREQ6 (unsigned char) 64
+#define FREQ3 (unsigned char) 128
+
+typedef char* node; //inodes, lnodes, snodes
+typedef char* table; //inode table, lnode table, singleton table
+
+typedef unsigned char NODETYPE;
+
+
+typedef enum {FIND, //!< search: find an entry
+ ENTER, //!< search: enter an entry
+ DELETE, //!< search: find and remove entry
+ INIT, //!< scan: start scan
+ CONT //!< scan: continue scan
+ } ACTION;
+
+
+typedef enum {COUNT, //!< table: only counters
+ LEAFPROB, //!< table: only probs on leafs
+ FLEAFPROB, //!< table: only probs on leafs and FROZEN
+ LEAFPROB2, //!< table: only probs on leafs
+ LEAFPROB3, //!< table: only probs on leafs
+ LEAFPROB4, //!< table: only probs on leafs
+ LEAFCODE, //!< table: only codes on leafs
+ SIMPLE_I, //!< table: simple interpolated LM
+ SIMPLE_B, //!< table: simple backoff LM
+ SHIFTBETA_I, //!< table: interpolated shiftbeta
+ SHIFTBETA_B, //!< table: backoff shiftbeta
+ IMPROVEDSHIFTBETA_I,//!< table: interp improved shiftbeta
+ IMPROVEDSHIFTBETA_B,//!< table: interp improved shiftbeta
+ KNESERNEY_I,//!< table: interp kneser-ney
+ KNESERNEY_B,//!< table: backoff kneser-ney
+ IMPROVEDKNESERNEY_I,//!< table: interp improved kneser-ney
+ IMPROVEDKNESERNEY_B,//!< table: backoff improved kneser-ney
+ FULL, //!< table: full fledged table
+
+ } TABLETYPE;
+
// Layout descriptor shared by every table flavour: the word-code size,
// the byte offsets of the fields inside internal/leaf node records and
// the number/size of frequency fields, all derived from the TABLETYPE
// chosen at construction (see the constructor in ngramtable.cpp).
class tabletype
{

  TABLETYPE ttype;

public:

  int CODESIZE; //sizeof word codes
  long long code_range[7]; //max code for each size

  //Offsets of internal node fields
  int WORD_OFFS; //word code position
  int MSUCC_OFFS; //number of successors
  int MTAB_OFFS; //pointer to successors
  int FLAGS_OFFS; //flag table
  int SUCC1_OFFS; //number of successors with freq=1
  int SUCC2_OFFS; //number of successors with freq=2
  int BOFF_OFFS; //back-off probability
  int I_FREQ_OFFS; //frequency offset
  int I_FREQ_NUM; //number of internal frequencies
  int L_FREQ_NUM; //number of leaf frequencies
  int L_FREQ_SIZE; //minimum size for leaf frequencies

  //Offsets of leaf node fields
  int L_FREQ_OFFS; //frequency offset

  tabletype(TABLETYPE tt,int codesize=DEFCODESIZE);

  inline TABLETYPE tbtype() const {
    return ttype;
  }

  // byte size of an internal-node record with s-byte frequency fields
  inline int inodesize(int s) const {
    return I_FREQ_OFFS + I_FREQ_NUM * s;
  }

  // byte size of a leaf-node record with s-byte frequency fields
  inline int lnodesize(int s) const {
    return L_FREQ_OFFS + L_FREQ_NUM * s;
  }

};
+
+
+class ngramtable:tabletype
+{
+
+ node tree; // ngram table root
+ int maxlev; // max storable n-gram
+ NODETYPE treeflags;
+ char info[100]; //information put in the header
+ int resolution; //max resolution for probabilities
+ double decay; //decay constant
+
+ storage* mem; //memory storage class
+
+ long long* memory; // memory load per level
+ long long* occupancy; // memory occupied per level
+ long long* mentr; // multiple entries per level
+ long long card; //entries at maxlev
+
+ int idx[MAX_NGRAM+1];
+
+ int oov_code,oov_size,du_code, bo_code; //used by prob
+
+ int backoff_state; //used by prob;
+
+public:
+
+ int corrcounts; //corrected counters flag
+
+ dictionary *dict; // dictionary
+
+ // filtering dictionary:
+ // if the first word of the ngram does not belong to filterdict
+ // do not insert the ngram
+ dictionary *filterdict;
+
+ ngramtable(char* filename,int maxl,char* is,
+ dictionary* extdict,
+ char* filterdictfile,
+ int googletable=0,
+ int dstco=0,char* hmask=NULL,int inplen=0,
+ TABLETYPE tt=FULL,int codesize=DEFCODESIZE);
+
+ inline char* ngtype(char *str=NULL) {
+ if (str!=NULL) strcpy(info,str);
+ return info;
+ }
+
+ virtual ~ngramtable();
+
+ inline void freetree() {
+ freetree(tree);
+ };
+
+ void freetree(node nd);
+
+ void resetngramtable();
+
+ void stat(int level=4);
+
+ inline long long totfreq(long long v=-1) {
+ return (v==-1?freq(tree,INODE):freq(tree,INODE,v));
+ }
+
+ inline long long btotfreq(long long v=-1) {
+ return (v==-1?getfreq(tree,treeflags,1):setfreq(tree,treeflags,v,1));
+ }
+
+ inline long long entries(int lev) const {
+ return mentr[lev];
+ }
+
+ inline int maxlevel() const {
+ return maxlev;
+ }
+
+ // void savetxt(char *filename,int sz=0);
+ void savetxt(char *filename,int sz=0,bool googleformat=false,bool hashvalue=false,int startfrom=0);
+ void loadtxt(char *filename,int googletable=0);
+
+ void savebin(char *filename,int sz=0);
+ void savebin(mfstream& out);
+ void savebin(mfstream& out,node nd,NODETYPE ndt,int lev,int mlev);
+
+ void loadbin(const char *filename);
+ void loadbin(mfstream& inp);
+ void loadbin(mfstream& inp,node nd,NODETYPE ndt,int lev);
+
+ void loadbinold(char *filename);
+ void loadbinold(mfstream& inp,node nd,NODETYPE ndt,int lev);
+
+ void generate(char *filename,dictionary *extdict=NULL);
+ void generate_dstco(char *filename,int dstco);
+ void generate_hmask(char *filename,char* hmask,int inplen=0);
+
+ void augment(ngramtable* ngt);
+
+ inline int scan(ngram& ng,ACTION action=CONT,int maxlev=-1) {
+ return scan(tree,INODE,0,ng,action,maxlev);
+ }
+
+ inline int succscan(ngram& h,ngram& ng,ACTION action,int lev) {
+ //return scan(h.link,h.info,h.lev,ng,action,lev);
+ return scan(h.link,h.info,lev-1,ng,action,lev);
+ }
+
+ double prob(ngram ng);
+
+ int scan(node nd,NODETYPE ndt,int lev,ngram& ng,ACTION action=CONT,int maxl=-1);
+
+ void show();
+
+ void *search(table *tb,NODETYPE ndt,int lev,int n,int sz,int *w,
+ ACTION action,char **found=(char **)NULL);
+
+ int mybsearch(char *ar, int n, int size, unsigned char *key, int *idx);
+
+ int put(ngram& ng);
+ int put(ngram& ng,node nd,NODETYPE ndt,int lev);
+
+ inline int get(ngram& ng) {
+ return get(ng,maxlev,maxlev);
+ }
+ virtual int get(ngram& ng,int n,int lev);
+
+ int comptbsize(int n);
+ table *grow(table *tb,NODETYPE ndt,int lev,int n,int sz,NODETYPE oldndt=0);
+
+ bool check_dictsize_bound();
+
+ int putmem(char* ptr,int value,int offs,int size);
+ int getmem(char* ptr,int* value,int offs,int size);
+ long putmem(char* ptr,long long value,int offs,int size);
+ long getmem(char* ptr,long long* value,int offs,int size);
+
+ inline void tb2ngcpy(int* wordp,char* tablep,int n=1);
+ inline void ng2tbcpy(char* tablep,int* wordp,int n=1);
+ inline int ngtbcmp(int* wordp,char* tablep,int n=1);
+
+ inline int word(node nd,int value) {
+ putmem(nd,value,WORD_OFFS,CODESIZE);
+ return value;
+ }
+
+ inline int word(node nd) {
+ int v;
+ getmem(nd,&v,WORD_OFFS,CODESIZE);
+ return v;
+ }
+
+ inline unsigned char mtflags(node nd,unsigned char value) {
+ return *(unsigned char *)(nd+FLAGS_OFFS)=value;
+ }
+
+ inline unsigned char mtflags(node nd) const {
+ return *(unsigned char *)(nd+FLAGS_OFFS);
+ }
+
+ int codecmp(char * a,char *b);
+
+ inline int codediff(node a,node b) {
+ return word(a)-word(b);
+ };
+
+
+ int update(ngram ng);
+
+ long long freq(node nd,NODETYPE ndt,long long value);
+ long long freq(node nd,NODETYPE ndt);
+
+ long long setfreq(node nd,NODETYPE ndt,long long value,int index=0);
+ long long getfreq(node nd,NODETYPE ndt,int index=0);
+
+ double boff(node nd) {
+ int value=0;
+ getmem(nd,&value,BOFF_OFFS,INTSIZE);
+
+ return double (value/(double)1000000000.0);
+ }
+
+
+ double myround(double x) {
+ long int i=(long int)(x);
+ return (x-i)>0.500?i+1.0:(double)i;
+ }
+
+ int boff(node nd,double value) {
+ int v=(int)myround(value * 1000000000.0);
+ putmem(nd,v,BOFF_OFFS,INTSIZE);
+
+ return 1;
+ }
+
+ int succ2(node nd,int value) {
+ putmem(nd,value,SUCC2_OFFS,CODESIZE);
+ return value;
+ }
+
+ int succ2(node nd) {
+ int value=0;
+ getmem(nd,&value,SUCC2_OFFS,CODESIZE);
+ return value;
+ }
+
+ int succ1(node nd,int value) {
+ putmem(nd,value,SUCC1_OFFS,CODESIZE);
+ return value;
+ }
+
+ int succ1(node nd) {
+ int value=0;
+ getmem(nd,&value,SUCC1_OFFS,CODESIZE);
+ return value;
+ }
+
+ int msucc(node nd,int value) {
+ putmem(nd,value,MSUCC_OFFS,CODESIZE);
+ return value;
+ }
+
+ int msucc(node nd) {
+ int value;
+ getmem(nd,&value,MSUCC_OFFS,CODESIZE);
+ return value;
+ }
+
+ table mtable(node nd);
+ table mtable(node nd,table value);
+ int mtablesz(node nd);
+
+ inline int bo_state() {
+ return backoff_state;
+ }
+ inline int bo_state(int value) {
+ return backoff_state=value;
+ }
+};
+
+#endif
+
+
+
+
diff --git a/src/ngt.cpp b/src/ngt.cpp
new file mode 100644
index 0000000..511e693
--- /dev/null
+++ b/src/ngt.cpp
@@ -0,0 +1,506 @@
+// $Id: ngt.cpp 245 2009-04-02 14:05:40Z fabio_brugnara $
+
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+
+// ngt
+// by M. Federico
+// Copyright Marcello Federico, ITC-irst, 1998
+
+
+#include <iostream>
+#include <sstream>
+#include <cmath>
+#include "util.h"
+#include "cmd.h"
+#include "mfstream.h"
+#include "mempool.h"
+#include "htable.h"
+#include "dictionary.h"
+#include "n_gram.h"
+#include "ngramtable.h"
+
+using namespace std;
+
+void print_help(int TypeFlag=0){
+ std::cerr << std::endl << "ngt - collects n-grams" << std::endl;
+ std::cerr << std::endl << "USAGE:" << std::endl;
+ std::cerr << " ngt -i=<inputfile> [options]" << std::endl;
+ std::cerr << std::endl << "OPTIONS:" << std::endl;
+
+ FullPrintParams(TypeFlag, 0, 1, stderr);
+}
+
+void usage(const char *msg = 0)
+{
+ if (msg){
+ std::cerr << msg << std::endl;
+ }
+ else{
+ print_help();
+ }
+}
+
+int main(int argc, char **argv)
+{
+ char *inp=NULL;
+ char *out=NULL;
+ char *dic=NULL; // dictionary filename
+ char *subdic=NULL; // subdictionary filename
+ char *filterdict=NULL; // filter dictionary filename
+ char *filtertable=NULL; // ngramtable filename
+ char *iknfile=NULL; // filename to save IKN statistics
+ double filter_hit_rate=1.0; // minimum hit rate of filter
+ char *aug=NULL; // augmentation data
+ char *hmask=NULL; // historymask
+ bool inputgoogleformat=false; //reads ngrams in Google format
+ bool outputgoogleformat=false; //print ngrams in Google format
+ bool outputredisformat=false; //print ngrams in Redis format
+ int ngsz=0; // n-gram default size
+ int dstco=0; // compute distance co-occurrences
+ bool bin=false;
+ bool ss=false; //generate single table
+ bool LMflag=false; //work with LM table
+ bool saveeach=false; //save all n-gram orders
+ int inplen=0; //input length for mask generation
+ bool tlm=false; //test lm table
+ char* ftlm=NULL; //file to test LM table
+
+ bool memuse=false;
+ bool help=false;
+
+
+ DeclareParams((char*)
+ "Dictionary", CMDSTRINGTYPE|CMDMSG, &dic, "dictionary filename",
+ "d", CMDSTRINGTYPE|CMDMSG, &dic, "dictionary filename",
+
+ "NgramSize", CMDSUBRANGETYPE|CMDMSG, &ngsz, 1, MAX_NGRAM, "n-gram default size; default is 0",
+ "n", CMDSUBRANGETYPE|CMDMSG, &ngsz, 1, MAX_NGRAM, "n-gram default size; default is 0",
+ "InputFile", CMDSTRINGTYPE|CMDMSG, &inp, "input file",
+ "i", CMDSTRINGTYPE|CMDMSG, &inp, "input file",
+ "OutputFile", CMDSTRINGTYPE|CMDMSG, &out, "output file",
+ "o", CMDSTRINGTYPE|CMDMSG, &out, "output file",
+ "InputGoogleFormat", CMDBOOLTYPE|CMDMSG, &inputgoogleformat, "the input file contains data in the n-gram Google format; default is false",
+ "gooinp", CMDBOOLTYPE|CMDMSG, &inputgoogleformat, "the input file contains data in the n-gram Google format; default is false",
+ "OutputGoogleFormat", CMDBOOLTYPE|CMDMSG, &outputgoogleformat, "the output file contains data in the n-gram Google format; default is false",
+ "gooout", CMDBOOLTYPE|CMDMSG, &outputgoogleformat, "the output file contains data in the n-gram Google format; default is false",
+ "OutputRedisFormat", CMDBOOLTYPE|CMDMSG, &outputredisformat, "as Google format plus corresponding CRC.16 hash values; default is false",
+ "redisout", CMDBOOLTYPE|CMDMSG, &outputredisformat, "as Google format plus corresponding CRC.16 hash values; default is false",
+ "SaveEach", CMDBOOLTYPE|CMDMSG, &saveeach, "save all ngram orders; default is false",
+ "saveeach", CMDBOOLTYPE|CMDMSG, &saveeach, "save all ngram orders; default is false",
+ "SaveBinaryTable", CMDBOOLTYPE|CMDMSG, &bin, "saves into binary format; default is false",
+ "b", CMDBOOLTYPE|CMDMSG, &bin, "saves into binary format; default is false",
+ "LmTable", CMDBOOLTYPE|CMDMSG, &LMflag, "works with LM table; default is false",
+ "lm", CMDBOOLTYPE|CMDMSG, &LMflag, "works with LM table; default is false",
+ "DistCo", CMDINTTYPE|CMDMSG, &dstco, "computes distance co-occurrences at the specified distance; default is 0",
+ "dc", CMDINTTYPE|CMDMSG, &dstco, "computes distance co-occurrences at the specified distance; default is 0",
+ "AugmentFile", CMDSTRINGTYPE|CMDMSG, &aug, "augmentation data",
+ "aug", CMDSTRINGTYPE|CMDMSG, &aug, "augmentation data",
+ "SaveSingle", CMDBOOLTYPE|CMDMSG, &ss, "generates single table; default is false",
+ "ss", CMDBOOLTYPE|CMDMSG, &ss, "generates single table; default is false",
+ "SubDict", CMDSTRINGTYPE|CMDMSG, &subdic, "subdictionary",
+ "sd", CMDSTRINGTYPE|CMDMSG, &subdic, "subdictionary",
+ "FilterDict", CMDSTRINGTYPE|CMDMSG, &filterdict, "filter dictionary",
+ "fd", CMDSTRINGTYPE|CMDMSG, &filterdict, "filter dictionary",
+ "ConvDict", CMDSTRINGTYPE|CMDMSG, &subdic, "subdictionary",
+ "cd", CMDSTRINGTYPE|CMDMSG, &subdic, "subdictionary",
+ "FilterTable", CMDSTRINGTYPE|CMDMSG, &filtertable, "ngramtable filename",
+ "ftr", CMDDOUBLETYPE|CMDMSG, &filter_hit_rate, "ngramtable filename",
+ "FilterTableRate", CMDDOUBLETYPE|CMDMSG, &filter_hit_rate, "minimum hit rate of filter; default is 1.0",
+ "ft", CMDSTRINGTYPE|CMDMSG, &filtertable, "minimum hit rate of filter; default is 1.0",
+ "HistoMask",CMDSTRINGTYPE|CMDMSG, &hmask, "history mask",
+ "hm",CMDSTRINGTYPE|CMDMSG, &hmask, "history mask",
+ "InpLen",CMDINTTYPE|CMDMSG, &inplen, "input length for mask generation; default is 0",
+ "il",CMDINTTYPE|CMDMSG, &inplen, "input length for mask generation; default is 0",
+ "tlm", CMDBOOLTYPE|CMDMSG, &tlm, "test LM table; default is false",
+ "ftlm", CMDSTRINGTYPE|CMDMSG, &ftlm, "file to test LM table",
+ "memuse", CMDBOOLTYPE|CMDMSG, &memuse, "default is false",
+ "iknstat", CMDSTRINGTYPE|CMDMSG, &iknfile, "filename to save IKN statistics",
+
+ "Help", CMDBOOLTYPE|CMDMSG, &help, "print this help",
+ "h", CMDBOOLTYPE|CMDMSG, &help, "print this help",
+
+ (char *)NULL
+ );
+
+
+ if (argc == 1){
+ usage();
+ exit_error(IRSTLM_NO_ERROR);
+ }
+
+ GetParams(&argc, &argv, (char*) NULL);
+
+ if (help){
+ usage();
+ exit_error(IRSTLM_NO_ERROR);
+ }
+
+ if (inp==NULL) {
+ usage();
+ exit_error(IRSTLM_ERROR_DATA,"Warning: no input file specified");
+ };
+
+ if (out==NULL) {
+ cerr << "Warning: no output file specified!\n";
+ }
+
+ TABLETYPE table_type=COUNT;
+
+ if (LMflag) {
+ cerr << "Working with LM table\n";
+ table_type=LEAFPROB;
+ }
+
+
+ // check word order of subdictionary
+
+ if (filtertable) {
+
+ {
+ ngramtable ngt(filtertable,1,NULL,NULL,NULL,0,0,NULL,0,table_type);
+ mfstream inpstream(inp,ios::in); //google input table
+ mfstream outstream(out,ios::out); //google output table
+
+ cerr << "Filtering table " << inp << " assumed to be in Google Format with size " << ngsz << "\n";
+ cerr << "with table " << filtertable << " of size " << ngt.maxlevel() << "\n";
+ cerr << "with hit rate " << filter_hit_rate << "\n";
+
+ //order of filter table must be smaller than that of input n-grams
+ MY_ASSERT(ngt.maxlevel() <= ngsz);
+
+ //read input googletable of ngrams of size ngsz
+ //output entries made of at least X% n-grams contained in filtertable
+ //<unk> words are not accepted
+
+ ngram ng(ngt.dict), ng2(ng.dict);
+ double hits=0;
+ double maxhits=(double)(ngsz-ngt.maxlevel()+1);
+
+ long c=0;
+ while(inpstream >> ng) {
+
+ if (ng.size>= ngt.maxlevel()) {
+ //need to make a copy
+ ng2=ng;
+ ng2.size=ngt.maxlevel();
+ //cerr << "check if " << ng2 << " is contained: ";
+ hits+=(ngt.get(ng2)?1:0);
+ }
+
+ if (ng.size==ngsz) {
+ if (!(++c % 1000000)) cerr << ".";
+ //cerr << ng << " -> " << is_included << "\n";
+ //you reached the last word before freq
+ inpstream >> ng.freq;
+ //consistency check of n-gram
+ if (((hits/maxhits)>=filter_hit_rate) &&
+ (!ng.containsWord(ngt.dict->OOV(),ng.size))
+ )
+ outstream << ng << "\n";
+ hits=0;
+ ng.size=0;
+ }
+ }
+
+ outstream.flush();
+ inpstream.flush();
+ }
+
+ exit_error(IRSTLM_NO_ERROR);
+ }
+
+
+
+ //ngramtable* ngt=new ngramtable(inp,ngsz,NULL,dic,dstco,hmask,inplen,table_type);
+ ngramtable* ngt=new ngramtable(inp,ngsz,NULL,NULL,filterdict,inputgoogleformat,dstco,hmask,inplen,table_type);
+
+ if (aug) {
+ ngt->dict->incflag(1);
+ // ngramtable ngt2(aug,ngsz,isym,NULL,0,NULL,0,table_type);
+ ngramtable ngt2(aug,ngsz,NULL,NULL,NULL,0,0,NULL,0,table_type);
+ ngt->augment(&ngt2);
+ ngt->dict->incflag(0);
+ }
+
+
+ if (subdic) {
+
+ ngramtable *ngt2=new ngramtable(NULL,ngsz,NULL,NULL,NULL,0,0,NULL,0,table_type);
+
+ // enforce the subdict to follow the same word order of the main dictionary
+ dictionary tmpdict(subdic);
+ ngt2->dict->incflag(1);
+ for (int j=0; j<ngt->dict->size(); j++) {
+ if (tmpdict.encode(ngt->dict->decode(j)) != tmpdict.oovcode()) {
+ ngt2->dict->encode(ngt->dict->decode(j));
+ }
+ }
+ ngt2->dict->incflag(0);
+
+ ngt2->dict->cleanfreq();
+
+ //possibly include standard symbols
+ if (ngt->dict->encode(ngt->dict->EoS())!=ngt->dict->oovcode()) {
+ ngt2->dict->incflag(1);
+ ngt2->dict->encode(ngt2->dict->EoS());
+ ngt2->dict->incflag(0);
+ }
+ if (ngt->dict->encode(ngt->dict->BoS())!=ngt->dict->oovcode()) {
+ ngt2->dict->incflag(1);
+ ngt2->dict->encode(ngt2->dict->BoS());
+ ngt2->dict->incflag(0);
+ }
+
+
+ ngram ng(ngt->dict);
+ ngram ng2(ngt2->dict);
+
+ ngt->scan(ng,INIT,ngsz);
+ long c=0;
+ while (ngt->scan(ng,CONT,ngsz)) {
+ ng2.trans(ng);
+ ngt2->put(ng2);
+ if (!(++c % 1000000)) cerr << ".";
+ }
+
+ //makes ngt2 aware of oov code
+ int oov=ngt2->dict->getcode(ngt2->dict->OOV());
+ if(oov>=0) ngt2->dict->oovcode(oov);
+
+ for (int j=0; j<ngt->dict->size(); j++) {
+ ngt2->dict->incfreq(ngt2->dict->encode(ngt->dict->decode(j)),
+ ngt->dict->freq(j));
+ }
+
+ cerr <<" oov: " << ngt2->dict->freq(ngt2->dict->oovcode()) << "\n";
+
+ delete ngt;
+ ngt=ngt2;
+
+ }
+
+ if (ngsz < ngt->maxlevel() && hmask) {
+ cerr << "start projection of ngramtable " << inp
+ << " according to hmask\n";
+
+ int selmask[MAX_NGRAM];
+ memset(selmask, 0, sizeof(int)*MAX_NGRAM);
+
+ //parse hmask
+ selmask[0]=1;
+ int i=1;
+ for (size_t c=0; c<strlen(hmask); c++) {
+ cerr << hmask[c] << "\n";
+ if (hmask[c] == '1'){
+ selmask[i]=c+2;
+ i++;
+ }
+ }
+
+ if (i!= ngsz) {
+ std::stringstream ss_msg;
+ ss_msg << "wrong mask: 1 bits=" << i << " maxlev=" << ngsz;
+ exit_error(IRSTLM_ERROR_DATA, ss_msg.str());
+ }
+
+ if (selmask[ngsz-1] > ngt->maxlevel()) {
+ std::stringstream ss_msg;
+ ss_msg << "wrong mask: farthest bits=" << selmask[ngsz-1]
+ << " maxlev=" << ngt->maxlevel() << "\n";
+ exit_error(IRSTLM_ERROR_DATA, ss_msg.str());
+ }
+
+ //ngramtable* ngt2=new ngramtable(NULL,ngsz,NULL,NULL,0,NULL,0,table_type);
+ ngramtable* ngt2=new ngramtable(NULL,ngsz,NULL,NULL,NULL,0,0,NULL,0,table_type);
+
+ ngt2->dict->incflag(1);
+
+ ngram ng(ngt->dict);
+ ngram png(ngt->dict,ngsz);
+ ngram ng2(ngt2->dict,ngsz);
+
+ ngt->scan(ng,INIT,ngt->maxlevel());
+ long c=0;
+ while (ngt->scan(ng,CONT,ngt->maxlevel())) {
+ //projection
+ for (int j=0; j<ngsz; j++)
+ *png.wordp(j+1)=*ng.wordp(selmask[j]);
+ png.freq=ng.freq;
+ //transfer
+ ng2.trans(png);
+ ngt2->put(ng2);
+ if (!(++c % 1000000)) cerr << ".";
+ }
+
+ char info[100];
+ sprintf(info,"hm%s",hmask);
+ ngt2->ngtype(info);
+
+ //makes ngt2 aware of oov code
+ int oov=ngt2->dict->getcode(ngt2->dict->OOV());
+ if(oov>=0) ngt2->dict->oovcode(oov);
+
+ for (int j=0; j<ngt->dict->size(); j++) {
+ ngt2->dict->incfreq(ngt2->dict->encode(ngt->dict->decode(j)),
+ ngt->dict->freq(j));
+ }
+
+ cerr <<" oov: " << ngt2->dict->freq(ngt2->dict->oovcode()) << "\n";
+
+ delete ngt;
+ ngt=ngt2;
+ }
+
+
+ if (tlm && table_type==LEAFPROB) {
+ ngram ng(ngt->dict);
+ cout.setf(ios::scientific);
+
+ cout << "> ";
+ while(cin >> ng) {
+ ngt->bo_state(0);
+ if (ng.size>=ngsz) {
+ cout << ng << " p= " << log(ngt->prob(ng));
+ cout << " bo= " << ngt->bo_state() << "\n";
+ } else
+ cout << ng << " p= NULL\n";
+
+ cout << "> ";
+ }
+
+ }
+
+
+ if (ftlm && table_type==LEAFPROB) {
+
+ ngram ng(ngt->dict);
+ cout.setf(ios::fixed);
+ cout.precision(2);
+
+ mfstream inptxt(ftlm,ios::in);
+ int Nbo=0,Nw=0,Noov=0;
+ float logPr=0,PP=0,PPwp=0;
+
+ int bos=ng.dict->encode(ng.dict->BoS());
+
+ while(inptxt >> ng) {
+
+ // reset ngram at begin of sentence
+ if (*ng.wordp(1)==bos) {
+ ng.size=1;
+ continue;
+ }
+
+ ngt->bo_state(0);
+ if (ng.size>=1) {
+ logPr+=log(ngt->prob(ng));
+ if (*ng.wordp(1) == ngt->dict->oovcode())
+ Noov++;
+
+ Nw++;
+ if (ngt->bo_state()) Nbo++;
+ }
+ }
+
+ PP=exp(-logPr/Nw);
+ PPwp= PP * exp(Noov * log(10000000.0-ngt->dict->size())/Nw);
+
+ cout << "%%% NGT TEST OF SMT LM\n";
+ cout << "%% LM=" << inp << " SIZE="<< ngt->maxlevel();
+ cout << " TestFile="<< ftlm << "\n";
+ cout << "%% OOV PENALTY = 1/" << 10000000.0-ngt->dict->size() << "\n";
+
+
+ cout << "%% Nw=" << Nw << " PP=" << PP << " PPwp=" << PPwp
+ << " Nbo=" << Nbo << " Noov=" << Noov
+ << " OOV=" << (float)Noov/Nw * 100.0 << "%\n";
+
+ }
+
+
+ if (memuse) ngt->stat(0);
+
+
+ if (iknfile) { //compute and save statistics of Improved Kneser Ney smoothing
+
+ ngram ng(ngt->dict);
+ int n1,n2,n3,n4;
+ int unover3=0;
+ mfstream iknstat(iknfile,ios::out); //output of ikn statistics
+
+ for (int l=1; l<=ngt->maxlevel(); l++) {
+
+ cerr << "level " << l << "\n";
+ iknstat << "level: " << l << " ";
+
+ cerr << "computing statistics\n";
+
+ n1=0;
+ n2=0;
+ n3=0,n4=0;
+
+ ngt->scan(ng,INIT,l);
+
+ while(ngt->scan(ng,CONT,l)) {
+
+ //skip ngrams containing _OOV
+ if (l>1 && ng.containsWord(ngt->dict->OOV(),l)) {
+ //cerr << "skp ngram" << ng << "\n";
+ continue;
+ }
+
+ //skip n-grams containing </s> in context
+ if (l>1 && ng.containsWord(ngt->dict->EoS(),l-1)) {
+ //cerr << "skp ngram" << ng << "\n";
+ continue;
+ }
+
+ //skip 1-grams containing <s>
+ if (l==1 && ng.containsWord(ngt->dict->BoS(),l)) {
+ //cerr << "skp ngram" << ng << "\n";
+ continue;
+ }
+
+ if (ng.freq==1) n1++;
+ else if (ng.freq==2) n2++;
+ else if (ng.freq==3) n3++;
+ else if (ng.freq==4) n4++;
+ if (l==1 && ng.freq >=3) unover3++;
+
+ }
+
+
+ cerr << " n1: " << n1 << " n2: " << n2 << " n3: " << n3 << " n4: " << n4 << "\n";
+ iknstat << " n1: " << n1 << " n2: " << n2 << " n3: " << n3 << " n4: " << n4 << " unover3: " << unover3 << "\n";
+
+ }
+
+ }
+
+ if (out){
+ if (bin) ngt->savebin(out,ngsz);
+ else if (outputredisformat) ngt->savetxt(out,ngsz,true,true,
+ 1);
+ else if (outputgoogleformat) ngt->savetxt(out,ngsz,true,false);
+ else ngt->savetxt(out,ngsz,false,false);
+ }
+}
+
diff --git a/src/normcache.cpp b/src/normcache.cpp
new file mode 100644
index 0000000..f64c167
--- /dev/null
+++ b/src/normcache.cpp
@@ -0,0 +1,123 @@
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+
+
+#include "mfstream.h"
+#include "mempool.h"
+#include "htable.h"
+#include "dictionary.h"
+#include "n_gram.h"
+#include "ngramtable.h"
+#include "normcache.h"
+
+using namespace std;
+
+// Normalization factors cache
+
+normcache::normcache(dictionary* d)
+{
+ dict=d;
+ //trigram and bigram normalization cache
+
+ //ngt=new ngramtable(NULL,2,NULL,NULL,0,0,NULL,0,LEAFPROB);
+ ngt=new ngramtable(NULL,2,NULL,NULL,NULL,0,0,NULL,0,LEAFPROB);
+
+ maxcache[0]=d->size();//unigram cache
+ maxcache[1]=d->size();//bigram cache
+
+ cache[0]=new double[maxcache[0]];
+ cache[1]=new double[maxcache[1]];
+
+ for (int i=0; i<d->size(); i++)
+ cache[0][i]=cache[1][i]=0.0;
+
+ cachesize[0]=cachesize[1]=0;
+ hit=miss=0;
+}
+
+void normcache::expand(int n)
+{
+
+ int step=100000;
+ cerr << "Expanding cache ...\n";
+ double *newcache=new double[maxcache[n]+step];
+ memcpy(newcache,cache[n],sizeof(double)*maxcache[n]);
+ delete [] cache[n];
+ cache[n]=newcache;
+ for (int i=0; i<step; i++)
+ cache[n][maxcache[n]+i]=0;
+ maxcache[n]+=step;
+};
+
+
+double normcache::get(ngram ng,int size,double& value)
+{
+
+ if (size==2) {
+ if (*ng.wordp(2) < cachesize[0])
+ return value=cache[0][*ng.wordp(2)];
+ else
+ return value=0;
+ } else if (size==3) {
+ if (ngt->get(ng,size,size-1)) {
+ hit++;
+ // cerr << "hit " << ng << "\n";
+ return value=cache[1][ng.freq];
+ } else {
+ miss++;
+ return value=0;
+ }
+ }
+ return 0;
+}
+
+double normcache::put(ngram ng,int size,double value)
+{
+
+ if (size==2) {
+ if (*ng.wordp(2)>= maxcache[0]) expand(0);
+ cache[0][*ng.wordp(2)]=value;
+ cachesize[0]++;
+ return value;
+ } else if (size==3) {
+ if (ngt->get(ng,size,size-1))
+ return cache[1][ng.freq]=value;
+ else {
+ ngram histo(dict,2);
+ *histo.wordp(1)=*ng.wordp(2);
+ *histo.wordp(2)=*ng.wordp(3);
+ histo.freq=cachesize[1]++;
+ if (cachesize[1]==maxcache[1]) expand(1);
+ ngt->put(histo);
+ return cache[1][histo.freq]=value;
+ }
+ }
+ return 0;
+}
+
+void normcache::stat()
+{
+ std::cout << "misses " << miss << ", hits " << hit << "\n";
+}
+
+
+
+
+
diff --git a/src/normcache.h b/src/normcache.h
new file mode 100644
index 0000000..c789092
--- /dev/null
+++ b/src/normcache.h
@@ -0,0 +1,53 @@
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+
+#ifndef MF_NORMCACHE_H
+#define MF_NORMCACHE_H
+
+#include "dictionary.h"
+#include "ngramtable.h"
+
+// Normalization factors cache
+
+class normcache
+{
+ dictionary* dict;
+ ngramtable *ngt;
+ double* cache[2];
+ int cachesize[2];
+ int maxcache[2];
+ int hit;
+ int miss;
+
+public:
+ normcache(dictionary* d);
+ ~normcache() {
+ delete [] cache[0];
+ delete [] cache[1];
+ delete ngt;
+ }
+
+ void expand(int i);
+ double get(ngram ng,int size,double& value);
+ double put(ngram ng,int size,double value);
+ void stat();
+};
+#endif
+
diff --git a/src/plsa.cpp b/src/plsa.cpp
new file mode 100755
index 0000000..5d3c8b5
--- /dev/null
+++ b/src/plsa.cpp
@@ -0,0 +1,250 @@
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit, compile LM
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+
+
+#include <iostream>
+#include "cmd.h"
+#include <pthread.h>
+#include "thpool.h"
+#include "util.h"
+#include "mfstream.h"
+#include "mempool.h"
+#include "htable.h"
+#include "dictionary.h"
+#include "n_gram.h"
+#include "ngramtable.h"
+#include "doc.h"
+#include "cplsa.h"
+
+using namespace std;
+using namespace irstlm;
+
+void print_help(int TypeFlag=0){
+ std::cerr << std::endl << "plsa - probabilistic latent semantic analysis modeling" << std::endl;
+ std::cerr << std::endl << "USAGE:" << std::endl;
+ std::cerr << " plsa -tr|te=<text> -m=<model> -t=<n> [options]" << std::endl;
+ std::cerr << std::endl << "DESCRIPTION:" << std::endl;
+ std::cerr << " Train a PLSA model from a corpus and test it to infer topic or word " << std::endl;
+ std::cerr << " distributions from other texts." << std::endl;
+ std::cerr << " Notice: multithreading is available both for training and inference." << std::endl;
+
+ std::cerr << std::endl << "OPTIONS:" << std::endl;
+
+
+ FullPrintParams(TypeFlag, 0, 1, stderr);
+
+ std::cerr << std::endl << "EXAMPLES:" << std::endl;
+ std::cerr <<" (1) plsa -tr=<text> -t=<n> -m=<model> " << std::endl;
+ std::cerr <<" Train a PLSA model <model> with <n> topics on text <text> " << std::endl;
+ std::cerr <<" Example of <text> content:" << std::endl;
+ std::cerr <<" 3" << std::endl;
+ std::cerr <<" <d> hello world ! </d>" << std::endl;
+ std::cerr <<" <d> good morning good afternoon </d>" << std::endl;
+ std::cerr <<" <d> welcome aboard </d>" << std::endl;
+ std::cerr <<" (2) plsa -m=<model> -te=<text> -tf=<features>" << std::endl;
+ std::cerr <<" Infer topic distribution with model <model> for each doc in <text>" << std::endl;
+ std::cerr <<" (3) plsa -m=<model> -te=<text> -wf=<features>" << std::endl;
+ std::cerr <<" Infer word distribution with model <model> for each doc in <text>" << std::endl;
+ std::cerr << std::endl;
+}
+
+void usage(const char *msg = 0)
+{
+ if (msg){
+ std::cerr << msg << std::endl;
+ }
+ else{
+ print_help();
+ }
+}
+
+int main(int argc, char **argv){
+ char *dictfile=NULL;
+ char *trainfile=NULL;
+ char *testfile=NULL;
+ char *topicfeaturefile=NULL;
+ char *wordfeaturefile=NULL;
+ char *modelfile=NULL;
+ char *tmpdir = getenv("TMP");
+ char *txtfile=NULL;
+ bool forcemodel=false;
+
+ int topics=0; //number of topics
+ int specialtopic=0; //special topic: first st dict words
+ int iterations=10; //number of EM iterations to run
+ int threads=1; //current EM iteration for multi-thread training
+ bool help=false;
+ bool memorymap=true;
+ int prunethreshold=3;
+ int topwords=20;
+ DeclareParams((char*)
+
+ "Train", CMDSTRINGTYPE|CMDMSG, &trainfile, "<fname> : training text collection ",
+ "tr", CMDSTRINGTYPE|CMDMSG, &trainfile, "<fname> : training text collection ",
+
+ "Model", CMDSTRINGTYPE|CMDMSG, &modelfile, "<fname> : model file",
+ "m", CMDSTRINGTYPE|CMDMSG, &modelfile, "<fname> : model file",
+
+ "TopWordsFile", CMDSTRINGTYPE|CMDMSG, &txtfile, "<fname> to write top words per topic",
+ "twf", CMDSTRINGTYPE|CMDMSG, &txtfile, "<fname> to write top words per topic",
+
+ "PruneFreq", CMDINTTYPE|CMDMSG, &prunethreshold, "<count>: prune words with freq <= count (default 3)",
+ "pf", CMDINTTYPE|CMDMSG, &prunethreshold, "<count>: prune words with freq <= count (default 3)",
+
+ "TopWordsNum", CMDINTTYPE|CMDMSG, &topwords, "<count>: number of top words per topic ",
+ "twn", CMDINTTYPE|CMDMSG, &topwords, "<count>: number of top words per topic",
+
+ "Test", CMDSTRINGTYPE|CMDMSG, &testfile, "<fname> : inference text collection file",
+ "te", CMDSTRINGTYPE|CMDMSG, &testfile, "<fname> : inference text collection file",
+
+ "WordFeatures", CMDSTRINGTYPE|CMDMSG, &wordfeaturefile, "<fname> : unigram feature file",
+ "wf", CMDSTRINGTYPE|CMDMSG, &wordfeaturefile,"<fname> : unigram feature file",
+
+ "TopicFeatures", CMDSTRINGTYPE|CMDMSG, &topicfeaturefile, "<fname> : topic feature file",
+ "tf", CMDSTRINGTYPE|CMDMSG, &topicfeaturefile, "<fname> : topic feature file",
+
+ "Topics", CMDINTTYPE|CMDMSG, &topics, "<count> : number of topics (default 0)",
+ "t", CMDINTTYPE|CMDMSG, &topics,"<count> : number of topics (default 0)",
+
+ "SpecialTopic", CMDINTTYPE|CMDMSG, &specialtopic, "<count> : put top-<count> frequent words in a special topic (default 0)",
+ "st", CMDINTTYPE|CMDMSG, &specialtopic, "<count> : put top-<count> frequent words in a special topic (default 0)",
+
+ "Iterations", CMDINTTYPE|CMDMSG, &iterations, "<count> : training/inference iterations (default 10)",
+ "it", CMDINTTYPE|CMDMSG, &iterations, "<count> : training/inference iterations (default 10)",
+
+ "Threads", CMDINTTYPE|CMDMSG, &threads, "<count>: number of threads (default 1)",
+ "th", CMDINTTYPE|CMDMSG, &threads, "<count>: number of threads (default 1)",
+
+ "ForceModel", CMDBOOLTYPE|CMDMSG, &forcemodel, "<bool>: force to use existing model for training",
+ "fm", CMDBOOLTYPE|CMDMSG, &forcemodel, "<bool>: force to use existing model for training",
+
+ "MemoryMap", CMDBOOLTYPE|CMDMSG, &memorymap, "<bool>: use memory mapping (default true)",
+ "mm", CMDBOOLTYPE|CMDMSG, &memorymap, "<bool>: use memory mapping (default true)",
+
+ "Dictionary", CMDSTRINGTYPE|CMDMSG, &dictfile, "<fname> : specify a training dictionary (optional)",
+ "d", CMDSTRINGTYPE|CMDMSG, &dictfile, "<fname> : specify a training dictionary (optional)",
+
+ "TmpDir", CMDSTRINGTYPE|CMDMSG, &tmpdir, "<folder>: tmp directory for memory map (default /tmp)",
+ "tmp", CMDSTRINGTYPE|CMDMSG, &tmpdir, "<folder>: tmp directory for memory map (default /tmp )",
+
+
+ "Help", CMDBOOLTYPE|CMDMSG, &help, "print this help",
+ "h", CMDBOOLTYPE|CMDMSG, &help, "print this help",
+
+ (char *)NULL
+ );
+
+ if (argc == 1){
+ usage();
+ exit_error(IRSTLM_NO_ERROR);
+ }
+
+ GetParams(&argc, &argv, (char*) NULL);
+
+ if (help){
+ usage();
+ exit_error(IRSTLM_NO_ERROR);
+ }
+
+
+ if (trainfile && ( !topics || !modelfile )) {
+ usage();
+ exit_error(IRSTLM_ERROR_DATA,"Missing training parameters");
+ }
+
+ if (testfile && (!modelfile || !(topicfeaturefile || wordfeaturefile))) {
+ usage();
+ exit_error(IRSTLM_ERROR_DATA,"Missing inference parameters");
+ }
+
+ dictionary *dict=NULL;
+
+ //Training phase
+ //test if model is readable
+ bool testmodel=false;
+ FILE* f;if ((f=fopen(modelfile,"r"))!=NULL){fclose(f);testmodel=true;}
+
+ if (trainfile){
+ if (testmodel){
+ if (!forcemodel)
+ //training with pretrained model: no need of dictionary
+ exit_error(IRSTLM_ERROR_DATA,"Use -ForceModel=y option to use and update an existing model.");
+ }
+ else{//training with empty model and no dictionary: dictionary must be first extracted
+ if (!dictfile){
+
+ // exit_error(IRSTLM_ERROR_DATA,"Missing dictionary. Provide a dictionary with option -d.");
+
+ cerr << "Extracting dictionary from training data (word with freq>=" << prunethreshold << ")\n";
+ dict=new dictionary(NULL,10000);
+ dict->generate(trainfile,true);
+
+ dictionary *sortd=new dictionary(dict,true,prunethreshold);
+ sortd->sort();
+ delete dict;
+ dict=sortd;
+
+ }
+ else
+ dict=new dictionary(dictfile,10000);
+ dict->encode(dict->OOV());
+ }
+
+ plsa tc(dict,topics,tmpdir,threads,memorymap);
+ tc.train(trainfile,modelfile,iterations,0.5,specialtopic);
+ if (dict!=NULL) delete dict;
+ }
+
+ //Inference phase
+ //test if model is readable: notice test could be executed after training
+
+ testmodel=false;
+ if ((f=fopen(modelfile,"r"))!=NULL){fclose(f);testmodel=true;}
+
+ if (testfile){
+ if (!testmodel)
+ exit_error(IRSTLM_ERROR_DATA,"Cannot read model file to run test inference.");
+ if (dictfile) cerr << "Will rely on model dictionary.";
+
+ dict=NULL;
+ plsa tc(dict,topics,tmpdir,threads,memorymap);
+ tc.inference(testfile,modelfile,iterations,topicfeaturefile,wordfeaturefile);
+ if (dict!=NULL) delete dict;
+ }
+
+
+ //save/convert model in text format
+
+ if (txtfile){
+ if (!testmodel)
+ exit_error(IRSTLM_ERROR_DATA,"Cannot open model to be printed in readable format.");
+
+ dict=NULL;
+ plsa tc(dict,topics,tmpdir,threads,memorymap);
+ tc.initW(modelfile,1,0);
+ tc.saveWtxt(txtfile,topwords);
+ tc.freeW();
+ }
+
+ exit_error(IRSTLM_NO_ERROR);
+}
+
+
+
diff --git a/src/prune-lm.cpp b/src/prune-lm.cpp
new file mode 100644
index 0000000..6cdd007
--- /dev/null
+++ b/src/prune-lm.cpp
@@ -0,0 +1,175 @@
+// $Id: prune-lm.cpp 27 2010-05-03 14:33:51Z nicolabertoldi $
+
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit, prune LM
+ Copyright (C) 2008 Fabio Brugnara, FBK-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+
+
+#include <iostream>
+#include <sstream>
+#include <fstream>
+#include <vector>
+#include <string>
+#include <stdlib.h>
+#include "cmd.h"
+#include "util.h"
+#include "math.h"
+#include "lmtable.h"
+
+/********************************/
+using namespace std;
+using namespace irstlm;
+
+// Write the tool banner, usage synopsis, description and the full
+// option summary (via the cmd library) to standard error.
+void print_help(int TypeFlag=0){
+  std::ostream& os = std::cerr;
+  os << std::endl << "prune-lm - prunes language models" << std::endl;
+  os << std::endl << "USAGE:" << std::endl;
+  os << " prune-lm [options] <inputfile> [<outputfile>]" << std::endl;
+  os << std::endl << "DESCRIPTION:" << std::endl;
+  os << " prune-lm reads a LM in either ARPA or compiled format and" << std::endl;
+  os << " prunes out n-grams (n=2,3,..) for which backing-off to the" << std::endl;
+  os << " lower order n-gram results in a small difference in probability." << std::endl;
+  os << " The pruned LM is saved in ARPA format" << std::endl;
+  os << std::endl << "OPTIONS:" << std::endl;
+
+  FullPrintParams(TypeFlag, 0, 1, stderr);
+}
+
+// Print the given message, or the full help text when no message is supplied.
+void usage(const char *msg = 0)
+{
+  if (msg)
+    std::cerr << msg << std::endl;
+  else
+    print_help();
+}
+
+// Parse a comma-separated threshold list into thr[1..MAX_NGRAM-1].
+// thr[0] is unused and zeroed; when fewer than MAX_NGRAM-1 values are
+// supplied, the last one is replicated for all remaining n-gram levels.
+// Fix: bound the token loop at MAX_NGRAM — the original kept consuming
+// tokens and wrote past the end of thr[] when too many values were given.
+void s2t(string cps, float *thr)
+{
+  int i;
+  char *s=strdup(cps.c_str());
+  char *tk;
+
+  thr[0]=0;
+  for(i=1,tk=strtok(s, ","); tk && i<MAX_NGRAM; tk=strtok(0, ","),i++) thr[i]=atof(tk);
+  for(; i<MAX_NGRAM; i++) thr[i]=thr[i-1];
+  free(s);
+}
+
+// Entry point: parse options, derive the output file name, load the LM,
+// prune it with the per-level thresholds and save it in ARPA format.
+int main(int argc, char **argv)
+{
+  float thr[MAX_NGRAM];   // per-level pruning thresholds (thr[0] unused)
+  char *spthr=NULL;       // raw comma-separated threshold string
+  int aflag=0;            // use absolute value of weighted difference
+  std::vector<std::string> files;
+
+  bool help=false;
+
+  DeclareParams((char*)
+                "threshold", CMDSTRINGTYPE|CMDMSG, &spthr, "pruning thresholds for 2-grams, 3-grams, 4-grams,...; if less thresholds are specified, the last one is applied to all following n-gram levels; default is 0",
+                "t", CMDSTRINGTYPE|CMDMSG, &spthr, "pruning thresholds for 2-grams, 3-grams, 4-grams,...; if less thresholds are specified, the last one is applied to all following n-gram levels; default is 0",
+
+                "abs", CMDBOOLTYPE|CMDMSG, &aflag, "uses absolute value of weighted difference; default is 0",
+
+                "Help", CMDBOOLTYPE|CMDMSG, &help, "print this help",
+                "h", CMDBOOLTYPE|CMDMSG, &help, "print this help",
+
+                (char *)NULL
+                );
+
+  if (argc == 1){
+    usage();
+    exit_error(IRSTLM_NO_ERROR);
+  }
+
+  // collect positional arguments; "-" maps to stdin (first) or stdout (second)
+  int first_file=1;
+  for (int i=1; i < argc; i++) {
+    if (strcmp(argv[i],"-") == 0){
+      if (first_file == 1){
+        files.push_back("/dev/stdin");
+      }else if (first_file == 2){
+        files.push_back("/dev/stdout");
+      }else{
+        usage("Warning: You can use the value for the input or output file only");
+      }
+      first_file++;
+    }else if(argv[i][0] != '-'){
+      files.push_back(argv[i]);
+      first_file++;
+    }
+  }
+
+  GetParams(&argc, &argv, (char*) NULL);
+
+  if (help){
+    usage();
+    exit_error(IRSTLM_NO_ERROR);
+  }
+
+  if (files.size() > 2) {
+    usage();
+    exit_error(IRSTLM_ERROR_DATA,"Too many arguments");
+  }
+
+  if (files.size() < 1) {
+    usage();
+    exit_error(IRSTLM_ERROR_DATA,"Specify a LM file to read from");
+  }
+
+  memset(thr, 0, sizeof(thr));
+  if(spthr != NULL) s2t(spthr, thr);
+  std::string infile = files[0];
+  std::string outfile= "";
+
+  if (files.size() == 1) {
+    // no output file given: derive it from the input file name
+    outfile=infile;
+
+    //remove path information
+    std::string::size_type p = outfile.rfind('/');
+    if (p != std::string::npos && ((p+1) < outfile.size()))
+      outfile.erase(0,p+1);
+
+    //eventually strip .gz
+    // fix: guard the length — compare() throws std::out_of_range when the
+    // name is shorter than 3 characters
+    if (outfile.size()>=3 && outfile.compare(outfile.size()-3,3,".gz")==0)
+      outfile.erase(outfile.size()-3,3);
+
+    outfile+=".plm";
+  } else
+    outfile = files[1];
+
+  lmtable lmt;
+  inputfilestream inp(infile.c_str());
+  if (!inp.good()) {
+    std::stringstream ss_msg;
+    ss_msg << "Failed to open " << infile;
+    exit_error(IRSTLM_ERROR_IO, ss_msg.str());
+  }
+
+  lmt.load(inp,infile.c_str(),outfile.c_str(),0);
+  std::cerr << "pruning LM with thresholds: \n";
+
+  for (int i=1; i<lmt.maxlevel(); i++) std::cerr<< " " << thr[i];
+  std::cerr << "\n";
+  lmt.wdprune((float*)thr, aflag);
+  lmt.savetxt(outfile.c_str());
+
+  exit_error(IRSTLM_NO_ERROR);
+}
+
diff --git a/src/quantize-lm.cpp b/src/quantize-lm.cpp
new file mode 100644
index 0000000..d713349
--- /dev/null
+++ b/src/quantize-lm.cpp
@@ -0,0 +1,512 @@
+// $Id: quantize-lm.cpp 302 2009-08-25 13:04:13Z nicolabertoldi $
+
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit, compile LM
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+
+
+#include <iostream>
+#include <sstream>
+#include <fstream>
+#include <vector>
+#include <string>
+#include <stdlib.h>
+#include "cmd.h"
+#include "math.h"
+#include "util.h"
+#include "mfstream.h"
+
+
+using namespace std;
+
+//----------------------------------------------------------------------
+// Special type and global variable for the BIN CLUSTERING algorithm
+//
+//
+//----------------------------------------------------------------------
+
+
+// One probability (or back-off weight) entry to be quantized.
+typedef struct {
+ float pt; // log10 value used as the clustering key
+ unsigned int idx; // original position within the n-gram block
+ unsigned short code; // centroid index assigned by ComputeCluster()
+} DataItem;
+
+
+// qsort comparator ordering entries by their leading float field
+// (pt is the first member of DataItem, so a DataItem* can be read
+// through a float*).
+int cmpFloatEntry(const void* a,const void* b)
+{
+  const float x = *(const float*)a;
+  const float y = *(const float*)b;
+  if (x > y) return 1;
+  if (x < y) return -1;
+  return 0;
+}
+
+//----------------------------------------------------------------------
+// Global entry points
+//----------------------------------------------------------------------
+
+int ComputeCluster(int nc, double* cl,unsigned int N,DataItem* Pts);
+
+//----------------------------------------------------------------------
+// Global parameters (some are set in getArgs())
+//----------------------------------------------------------------------
+
+int k = 256; // number of centers
+const int MAXLEV = 11; //maximum n-gram size
+
+//----------------------------------------------------------------------
+// Main program
+//----------------------------------------------------------------------
+
+// Write the tool banner, usage synopsis, description and the full
+// option summary (via the cmd library) to standard error.
+void print_help(int TypeFlag=0){
+  std::ostream& os = std::cerr;
+  os << std::endl << "quantize-lm - quantizes probabilities and back-off weights" << std::endl;
+  os << std::endl << "USAGE:" << std::endl;
+  os << " quantize-lm <input-file.lm> [<output-file.qlm> [<tmpfile>]]" << std::endl;
+  os << std::endl << "DESCRIPTION:" << std::endl;
+  os << " quantize-lm reads a standard LM file in ARPA format and produces" << std::endl;
+  os << " a version of it with quantized probabilities and back-off weights" << std::endl;
+  os << " that the IRST LM toolkit can compile. Accepts LMs with .gz suffix." << std::endl;
+  os << " You can specify the output file to be created and also the pathname" << std::endl;
+  os << " of a temporary file used by the program. As default, the temporary " << std::endl;
+  os << " file is created in the /tmp directory." << std::endl;
+  os << " Output file can be written to standard output by using the special name -." << std::endl;
+  os << std::endl << "OPTIONS:" << std::endl;
+
+  FullPrintParams(TypeFlag, 0, 1, stderr);
+}
+
+// Print the given message, or the full help text when no message is supplied.
+void usage(const char *msg = 0)
+{
+  if (msg) {
+    std::cerr << msg << std::endl;
+    return;
+  }
+  print_help();
+}
+
+// Entry point: read an ARPA LM, quantize each level's probabilities and
+// back-off weights into at most k centroids (bin clustering), and emit
+// the result in qARPA format.
+// Fixes: (1) never cast std::cout to std::ofstream* — calling close()
+// on it through the wrong type is undefined behavior, and the heap
+// ofstream was leaked; a plain std::ostream* plus a stack ofstream is
+// used instead. (2) guard the ".gz" suffix test against short names
+// (std::out_of_range). (3) exit cleanly after printing help when no
+// arguments are given, as the sibling tools do.
+int main(int argc, char **argv)
+{
+
+  std::vector<std::string> files;
+
+  bool help=false;
+
+  DeclareParams((char*)
+                "Help", CMDBOOLTYPE|CMDMSG, &help, "print this help",
+                "h", CMDBOOLTYPE|CMDMSG, &help, "print this help",
+
+                (char *)NULL
+                );
+
+  if (argc == 1){
+    usage();
+    exit_error(IRSTLM_NO_ERROR);
+  }
+
+  // collect positional arguments; "-" maps to stdin (first) or stdout (second)
+  int first_file=1;
+  for (int i=1; i < argc; i++) {
+    if (strcmp(argv[i],"-") == 0){
+      if (first_file == 1){
+        files.push_back("/dev/stdin");
+      }else if (first_file == 2){
+        files.push_back("/dev/stdout");
+      }else{
+        usage("Warning: You can use the value for the input and/or output file only");
+      }
+      first_file++;
+    }else if(argv[i][0] != '-'){
+      files.push_back(argv[i]);
+      first_file++;
+    }
+  }
+
+  GetParams(&argc, &argv, (char*) NULL);
+
+  if (help){
+    usage();
+    exit_error(IRSTLM_NO_ERROR);
+  }
+  if (files.size() > 3) {
+    exit_error(IRSTLM_ERROR_DATA,"Too many arguments");
+  }
+
+  if (files.size() < 1) {
+    usage();
+    exit_error(IRSTLM_ERROR_DATA,"Please specify a LM file to read from");
+  }
+
+  std::string infile = files[0];
+  std::string outfile="";
+  std::string tmpfile="";
+
+  if (files.size() == 1) {
+    // no output file given: derive it from the input file name
+    outfile=infile;
+
+    //remove path information
+    std::string::size_type p = outfile.rfind('/');
+    if (p != std::string::npos && ((p+1) < outfile.size()))
+      outfile.erase(0,p+1);
+
+    //eventually strip .gz
+    // fix: guard the length — compare() throws std::out_of_range when the
+    // name is shorter than 3 characters
+    if (outfile.size()>=3 && outfile.compare(outfile.size()-3,3,".gz")==0)
+      outfile.erase(outfile.size()-3,3);
+
+    outfile+=".qlm";
+  } else
+    outfile = files[1];
+
+
+  if (files.size()==3) {
+    //create temporary file at the user-provided path
+    tmpfile = files[2];
+    mfstream dummy(tmpfile.c_str(),ios::out);
+    dummy.close();
+  } else {
+    //create temporary internal file in /tmp
+    mfstream dummy;
+    createtempfile(dummy,tmpfile,ios::out);
+    dummy.close();
+  }
+
+  std::cerr << "Reading " << infile << "..." << std::endl;
+
+  inputfilestream inp(infile.c_str());
+  if (!inp.good()) {
+    std::stringstream ss_msg;
+    ss_msg << "Failed to open " << infile;
+    exit_error(IRSTLM_ERROR_IO, ss_msg.str());
+  }
+
+  std::ostream* out;
+  std::ofstream outstr;   // used only when writing to a real file
+  if (outfile == "-")
+    out = &std::cout;
+  else {
+    outstr.open(outfile.c_str());
+    out = &outstr;
+  }
+  if (!out->good()) {
+    std::stringstream ss_msg;
+    ss_msg << "Failed to open " << outfile;
+    exit_error(IRSTLM_ERROR_IO, ss_msg.str());
+  }
+
+  std::cerr << "Writing " << outfile << "..." << std::endl;
+
+  //prepare temporary file to save n-gram blocks for multiple reads
+  //this avoids using seeks which do not work with inputfilestream
+  //it's odd but i need a bidirectional filestream!
+  std::cerr << "Using temporary file " << tmpfile << std::endl;
+  fstream filebuff(tmpfile.c_str(),ios::out|ios::in|ios::binary);
+
+  unsigned int nPts = 0; // actual number of points
+
+  // *** Read ARPA FILE **
+
+  unsigned int numNgrams[MAXLEV + 1]; /* # n-grams for each order */
+  int Order=0,MaxOrder=0;
+  int n=0;
+
+  float logprob,logbow;
+
+  DataItem* dataPts;
+
+  double* centersP=NULL;
+  double* centersB=NULL;
+
+  //maps from point index to code
+  unsigned short* mapP=NULL;
+  unsigned short* mapB=NULL;
+
+  int centers[MAXLEV + 1];
+  streampos iposition;
+
+  for (int i=1; i<=MAXLEV; i++) numNgrams[i]=0;
+  for (int i=1; i<=MAXLEV; i++) centers[i]=k;
+
+  /* all levels 256 centroids; in case read them as parameters */
+
+  char line[MAX_LINE];
+
+  while (inp.getline(line,MAX_LINE)) {
+
+    bool backslash = (line[0] == '\\');
+
+    if (sscanf(line, "ngram %d=%d", &Order, &n) == 2) {
+      numNgrams[Order] = n;
+      MaxOrder=Order;
+      continue;
+    }
+
+    if (!strncmp(line, "\\data\\", 6) || strlen(line)==0)
+      continue;
+
+    if (backslash && sscanf(line, "\\%d-grams", &Order) == 1) {
+
+      // print output header:
+      if (Order == 1) {
+        *out << "qARPA " << MaxOrder;
+        for (int i=1; i<=MaxOrder; i++)
+          *out << " " << centers[i];
+        *out << "\n\n\\data\\\n";
+
+        for (int i=1; i<=MaxOrder; i++)
+          *out << "ngram " << i << "= " << numNgrams[i] << "\n";
+      }
+
+      *out << "\n";
+      *out << line << "\n";
+      cerr << "-- Start processing of " << Order << "-grams\n";
+      MY_ASSERT(Order <= MAXLEV);
+
+      unsigned int N=numNgrams[Order];
+
+      const char* words[MAXLEV+3];
+      dataPts=new DataItem[N]; // allocate data
+
+      //reset tempout file to start writing
+      filebuff.seekg((streampos)0);
+
+      // first pass: copy the block to the temp file and collect log-probs
+      for (nPts=0; nPts<N; nPts++) {
+        inp.getline(line,MAX_LINE);
+        filebuff << line << std::endl;
+        if (!filebuff.good()) {
+          removefile(tmpfile.c_str());
+          std::stringstream ss_msg;
+          ss_msg << "Cannot write in temporary file " << tmpfile << std::endl
+                 << " Probably there is not enough space in this filesystem " << std::endl
+                 << " Eventually rerun quantize-lm by specifying the pathname" << std::endl
+                 << " of the temporary file to be used. ";
+          exit_error(IRSTLM_ERROR_IO, ss_msg.str());
+        }
+        int howmany = parseWords(line, words, Order + 3);
+        MY_ASSERT(howmany == Order+2 || howmany == Order+1);
+        sscanf(words[0],"%f",&logprob);
+        dataPts[nPts].pt=logprob; //exp(logprob * logten);
+        dataPts[nPts].idx=nPts;
+      }
+
+      cerr << "quantizing " << N << " probabilities\n";
+
+      centersP=new double[centers[Order]];
+      mapP=new unsigned short[N];
+
+      ComputeCluster(centers[Order],centersP,N,dataPts);
+
+      for (unsigned int p=0; p<N; p++) {
+        mapP[dataPts[p].idx]=dataPts[p].code;
+      }
+
+      if (Order<MaxOrder) {
+        //second pass to read back-off weights
+        //read from temporary file
+        filebuff.seekg((streampos)0);
+
+        for (nPts=0; nPts<N; nPts++) {
+
+          filebuff.getline(line,MAX_LINE);
+          int howmany = parseWords(line, words, Order + 3);
+          if (howmany==Order+2) //backoff is written
+            sscanf(words[Order+1],"%f",&logbow);
+          else
+            logbow=0; // backoff is implicit
+
+          dataPts[nPts].pt=logbow;
+          dataPts[nPts].idx=nPts;
+        }
+
+        centersB=new double[centers[Order]];
+        mapB=new unsigned short[N];
+
+        cerr << "quantizing " << N << " backoff weights\n";
+        ComputeCluster(centers[Order],centersB,N,dataPts);
+
+        for (unsigned int p=0; p<N; p++) {
+          mapB[dataPts[p].idx]=dataPts[p].code;
+        }
+
+      }
+
+      // emit the codebook for this level, then the coded n-grams
+      *out << centers[Order] << "\n";
+      for (int c=0; c<centers[Order]; c++) {
+        *out << centersP[c];
+        if (Order<MaxOrder) *out << " " << centersB[c];
+        *out << "\n";
+      }
+
+      filebuff.seekg(0);
+
+      for (nPts=0; nPts<numNgrams[Order]; nPts++) {
+
+        filebuff.getline(line,MAX_LINE);
+
+        parseWords(line, words, Order + 3);
+
+        *out << mapP[nPts];
+
+        for (int i=1; i<=Order; i++) *out << "\t" << words[i];
+
+        if (Order < MaxOrder) *out << "\t" << mapB[nPts];
+
+        *out << "\n";
+
+      }
+
+      if (mapP) {
+        delete [] mapP;
+        mapP=NULL;
+      }
+      if (mapB) {
+        delete [] mapB;
+        mapB=NULL;
+      }
+
+      if (centersP) {
+        delete [] centersP;
+        centersP=NULL;
+      }
+      if (centersB) {
+        delete [] centersB;
+        centersB=NULL;
+      }
+
+      delete [] dataPts;
+
+      continue;
+
+    }
+
+  }
+
+  *out << "\\end\\\n";
+  cerr << "---- done\n";
+
+  out->flush();
+
+  if (out == &outstr) outstr.close();  // never close std::cout
+  inp.close();
+
+  removefile(tmpfile.c_str());
+}
+
+// Compute Clusters
+
+// Bin-clustering quantizer: sort the N points by value, partition the
+// distinct values into at most `centers` bins of roughly equal cardinality,
+// assign each point a bin code, and set each bin's centroid ctrs[c] to the
+// log10 of the mean probability (10^pt averaged) of its members.
+// Returns 1; bintable entries come back sorted with their .code fields set.
+int ComputeCluster(int centers,double* ctrs,unsigned int N,DataItem* bintable)
+{
+
+
+ //cerr << "\nExecuting Clustering Algorithm: k=" << centers<< "\n";
+ double log10=log(10.0); // NOTE: shadows ::log10; used to convert log10 <-> natural log
+
+ for (unsigned int i=0; i<N; i++) bintable[i].code=0;
+
+ //cout << "start sort \n";
+ qsort(bintable,N,sizeof(DataItem),cmpFloatEntry);
+
+ unsigned int different=1;
+
+ // count distinct values to size the bins
+ for (unsigned int i=1; i<N; i++)
+ if (bintable[i].pt!=bintable[i-1].pt)
+ different++;
+
+ unsigned int interval=different/centers;
+ if (interval==0) interval++;
+
+ unsigned int* population=new unsigned int[centers]; // points per bin
+ unsigned int* species=new unsigned int[centers]; // distinct values per bin
+
+ //cerr << " Different entries=" << different
+ // << " Total Entries=" << N << " Bin Size=" << interval << "\n";
+
+ for (int i=0; i<centers; i++) {
+ population[i]=species[i]=0;
+ ctrs[i]=0;
+ }
+
+ // initial values: this should catch up very low values: -99
+ bintable[0].code=0;
+ population[0]=1;
+ species[0]=1;
+
+ int currcode=0;
+ different=1;
+
+ for (unsigned int i=1; i<N; i++) {
+
+ // open a new bin every `interval` distinct values, as long as codes remain
+ if ((bintable[i].pt!=bintable[i-1].pt)) {
+ different++;
+ if ((different % interval) == 0)
+ if ((currcode+1) < centers
+ &&
+ population[currcode]>0) {
+ currcode++;
+ }
+ }
+
+ // equal values always share the same code
+ if (bintable[i].pt == bintable[i-1].pt)
+ bintable[i].code=bintable[i-1].code;
+ else {
+ bintable[i].code=currcode;
+ species[currcode]++;
+ }
+
+ population[bintable[i].code]++;
+
+ MY_ASSERT(bintable[i].code < centers);
+
+ // accumulate the probability (not the log) for the centroid mean
+ ctrs[bintable[i].code]=ctrs[bintable[i].code]+exp(bintable[i].pt * log10);
+
+ }
+
+ // turn accumulated probability sums into log10 centroids; empty bins
+ // and underflows are clamped to the -99 sentinel
+ for (int i=0; i<centers; i++) {
+ if (population[i]>0)
+ ctrs[i]=log(ctrs[i]/population[i])/log10;
+ else
+ ctrs[i]=-99;
+
+ if (ctrs[i]<-99) {
+ cerr << "Warning: adjusting center with too small prob " << ctrs[i] << "\n";
+ ctrs[i]=-99;
+ }
+
+ cerr << i << " ctr " << ctrs[i] << " population " << population[i] << " species " << species[i] <<"\n";
+ }
+
+ cout.flush();
+
+ delete [] population;
+ delete [] species;
+
+
+ return 1;
+
+}
+
+//----------------------------------------------------------------------
+// Reading/Printing utilities
+// readPt - read a point from input stream into data storage
+// at position i. Returns false on error or EOF.
+// printPt - prints a point to the output file
+//----------------------------------------------------------------------
+
diff --git a/src/score-lm.cpp b/src/score-lm.cpp
new file mode 100644
index 0000000..4cb8708
--- /dev/null
+++ b/src/score-lm.cpp
@@ -0,0 +1,122 @@
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit
+Copyright (C) 2010 Christian Hardmeier, FBK-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+
+#include <cstdlib>
+#include <cstring>
+#include <fstream>
+#include <iostream>
+#include <sstream>
+#include <string>
+#include "cmd.h"
+#include "util.h"
+#include "lmtable.h"
+#include "n_gram.h"
+
+using namespace irstlm;
+
+// Write the tool banner, usage synopsis and option summary to stderr.
+void print_help(int TypeFlag=0){
+  std::cerr << std::endl << "score-lm - scores sentences with a language model" << std::endl;
+  std::cerr << std::endl << "USAGE:" << std::endl
+            << " score-lm -lm <model> [options]" << std::endl;
+  std::cerr << std::endl << "OPTIONS:" << std::endl;
+  std::cerr << " -lm language model to use (must be specified)" << std::endl;
+  // fix: the closing parenthesis after the default value was missing
+  std::cerr << " -dub dictionary upper bound (default: 10000000)" << std::endl;
+  std::cerr << " -level max level to load from the language models (default: 1000," << std::endl;
+  std::cerr << " meaning the actual LM order)" << std::endl;
+  std::cerr << " -mm 1 memory-mapped access to lm (default: 0)" << std::endl;
+  std::cerr << std::endl;
+
+  FullPrintParams(TypeFlag, 0, 1, stderr);
+}
+
+// Print the given message, or the full help text when no message is supplied.
+void usage(const char *msg = 0)
+{
+  if (msg) {
+    std::cerr << msg << std::endl;
+    return;
+  }
+  print_help();
+}
+
+// Entry point: load the LM, then read one sentence per line from stdin
+// and print the total log-probability of each sentence.
+int main(int argc, char **argv)
+{
+  int mmap = 0;                // memory-map the binary LM when nonzero
+  int dub = 10000000;          // dictionary upper bound for the OOV penalty
+  int requiredMaxlev = 1000;   // cap on the LM order to load
+  char *lm = NULL;             // LM file name (mandatory)
+
+  bool help=false;
+
+  DeclareParams((char*)
+                "lm", CMDSTRINGTYPE|CMDMSG, &lm, "language model to use (must be specified)",
+                "DictionaryUpperBound", CMDINTTYPE|CMDMSG, &dub, "dictionary upperbound to compute OOV word penalty: default 10^7",
+                "dub", CMDINTTYPE|CMDMSG, &dub, "dictionary upperbound to compute OOV word penalty: default 10^7",
+                "memmap", CMDINTTYPE|CMDMSG, &mmap, "uses memory map to read a binary LM",
+                "mm", CMDINTTYPE|CMDMSG, &mmap, "uses memory map to read a binary LM",
+                "level", CMDINTTYPE|CMDMSG, &requiredMaxlev, "maximum level to load from the LM; if value is larger than the actual LM order, the latter is taken",
+                "lev", CMDINTTYPE|CMDMSG, &requiredMaxlev, "maximum level to load from the LM; if value is larger than the actual LM order, the latter is taken",
+
+                "Help", CMDBOOLTYPE|CMDMSG, &help, "print this help",
+                "h", CMDBOOLTYPE|CMDMSG, &help, "print this help",
+
+                (char *)NULL
+                );
+
+  if (argc == 1){
+    usage();
+    exit_error(IRSTLM_NO_ERROR);
+  }
+
+  GetParams(&argc, &argv, (char*) NULL);
+
+  if (help){
+    usage();
+    exit_error(IRSTLM_NO_ERROR);
+  }
+
+
+  if(lm == NULL){
+    usage();
+    exit_error(IRSTLM_ERROR_DATA,"Missing parameter: please, specify the LM to use (-lm)");
+  }
+
+  std::ifstream lmstr(lm);
+  // fix: fail early with a clear message instead of handing load() a bad stream
+  if (!lmstr.good()) {
+    std::stringstream ss_msg;
+    ss_msg << "Failed to open " << lm;
+    exit_error(IRSTLM_ERROR_IO, ss_msg.str());
+  }
+  lmtable lmt;
+  lmt.setMaxLoadedLevel(requiredMaxlev);
+  lmt.load(lmstr, lm, NULL, mmap);
+  lmt.setlogOOVpenalty(dub);
+
+  for(;;) {
+    std::string line;
+    std::getline(std::cin, line);
+    if(!std::cin.good())
+      return !std::cin.eof();  // 0 on clean EOF, 1 on a read error
+
+    // score the sentence as the sum of n-gram log-probabilities
+    std::istringstream linestr(line);
+    ngram ng(lmt.dict);
+
+    double logprob = .0;
+    while((linestr >> ng))
+      logprob += lmt.lprob(ng);
+
+    std::cout << logprob << std::endl;
+  }
+}
diff --git a/src/shiftlm.cpp b/src/shiftlm.cpp
new file mode 100644
index 0000000..7dc4633
--- /dev/null
+++ b/src/shiftlm.cpp
@@ -0,0 +1,830 @@
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+
+#include <string.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <math.h>
+#include <sstream>
+#include "util.h"
+#include "mfstream.h"
+#include "mempool.h"
+#include "htable.h"
+#include "dictionary.h"
+#include "n_gram.h"
+#include "ngramtable.h"
+#include "ngramcache.h"
+#include "normcache.h"
+#include "interplm.h"
+#include "mdiadapt.h"
+#include "shiftlm.h"
+
+namespace irstlm {
+//
+//Shiftone interpolated language model
+//
+
+// Build a shift-one smoothed LM over the n-gram table in ngtfile.
+// A constant discount of beta=1 is subtracted from every observed count.
+shiftone::shiftone(char* ngtfile,int depth,int prunefreq,TABLETYPE tt):
+  mdiadaptlm(ngtfile,depth,tt)
+{
+  cerr << "Creating LM with ShiftOne smoothing\n";
+  beta=1.0;  // fixed discount for this smoothing scheme
+  prunethresh=prunefreq;
+  cerr << "PruneThresh: " << prunethresh << "\n";
+}
+
+
+// Estimate unigram statistics only: higher orders need no parameter
+// estimation because the discount is the constant beta=1.
+int shiftone::train()
+{
+  trainunigr();
+  return 1;
+}
+
+
+// Compute the discounted relative frequency fstar and the back-off weight
+// lambda for n-gram ng_ at order `size`, subtracting the constant discount
+// beta (=1) from each observed count. cv counts are removed from both the
+// n-gram and its history for cross-validation (leave-one-out) estimation.
+// Always returns 1; results are delivered through fstar and lambda.
+int shiftone::discount(ngram ng_,int size,double& fstar,double& lambda, int cv)
+{
+
+ ngram ng(dict);
+ ng.trans(ng_);
+
+ //cerr << "size:" << size << " ng:|" << ng <<"|\n";
+
+ if (size > 1) {
+
+ ngram history=ng;
+
+ // history must exist, survive the cv removal and (for order>=3) the
+ // history pruning threshold
+ if (ng.ckhisto(size) && get(history,size,size-1) && (history.freq>cv) &&
+ ((size < 3) || ((history.freq-cv) > prunethresh))) {
+
+ // this history is not pruned out
+
+ get(ng,size,size);
+ cv=(cv>ng.freq)?ng.freq:cv;
+
+ if (ng.freq > cv) {
+
+ fstar=(double)((double)(ng.freq - cv) - beta)/(double)(history.freq-cv);
+
+ lambda=beta * ((double)history.succ/(double)(history.freq-cv));
+
+ } else { // ng.freq == cv: do like if ng was deleted from the table
+
+ fstar=0.0;
+
+ lambda=beta * ((double)(history.succ-1)/ //one successor has disappeared!
+ (double)(history.freq-cv));
+
+ }
+
+ //cerr << "ngram :" << ng << "\n";
+
+ //check if the last word is OOV
+ if (*ng.wordp(1)==dict->oovcode()) {
+ lambda+=fstar;
+ fstar=0.0;
+ } else { //complete lambda with oovcode probability
+ *ng.wordp(1)=dict->oovcode();
+ if (get(ng,size,size))
+ lambda+=(double)((double)ng.freq - beta)/(double)(history.freq-cv);
+ }
+
+ } else {
+ // pruned or unseen history: back off entirely to the lower order
+ fstar=0;
+ lambda=1;
+ }
+ } else {
+ // order 1: plain unigram probability, no back-off mass
+ fstar=unigr(ng);
+ lambda=0.0;
+ }
+
+ return 1;
+}
+
+
+
+
+//
+//Shiftbeta interpolated language model
+//
+
+// Build a shift-beta smoothed LM. b==-1.0 requests automatic estimation of
+// the per-level discount (performed in train()); otherwise b must lie
+// strictly between 0 and 1 and is used for every level >= 2.
+shiftbeta::shiftbeta(char* ngtfile,int depth,int prunefreq,double b,TABLETYPE tt):
+ mdiadaptlm(ngtfile,depth,tt)
+{
+ cerr << "Creating LM with ShiftBeta smoothing\n";
+
+ if (b==-1.0 || (b < 1.0 && b >0.0)) {
+ beta=new double[lmsize()+1];
+ // levels 2..lmsize get b (or the -1 sentinel); beta[1] is set in train()
+ for (int l=lmsize(); l>1; l--)
+ beta[l]=b;
+ } else {
+ exit_error(IRSTLM_ERROR_DATA,"shiftbeta::shiftbeta beta must be < 1.0 and > 0");
+ }
+
+ prunethresh=prunefreq;
+ cerr << "PruneThresh: " << prunethresh << "\n";
+};
+
+
+
+// Estimate the shift-beta parameters: train the unigram model, then for
+// each level l>=2 count singleton (n1) and doubleton (n2) n-grams —
+// skipping OOV, sentence-boundary and <s> unigram cases — and, when the
+// discount was not fixed by the user (sentinel -1), set
+// beta[l] = n1 / (n1 + 2*n2). Also records succ1 statistics used later to
+// correct smoothing for singleton pruning. Returns 1.
+int shiftbeta::train()
+{
+ ngram ng(dict);
+ int n1,n2;
+
+ trainunigr();
+
+ beta[1]=0.0;
+
+ for (int l=2; l<=lmsize(); l++) {
+
+ cerr << "level " << l << "\n";
+ n1=0;
+ n2=0;
+ scan(ng,INIT,l);
+ while(scan(ng,CONT,l)) {
+
+
+ if (l<lmsize()) {
+ //Computing succ1 statistics for this ngram
+ //to correct smoothing due to singleton pruning
+
+ ngram hg=ng;
+ get(hg,l,l);
+ int s1=0;
+ ngram ng2=hg;
+ ng2.pushc(0);
+
+ succscan(hg,ng2,INIT,l+1);
+ while(succscan(hg,ng2,CONT,l+1)) {
+ if (ng2.freq==1) s1++;
+ }
+ succ1(hg.link,s1);
+ }
+
+ //skip ngrams containing _OOV
+ if (l>1 && ng.containsWord(dict->OOV(),l)) {
+ //cerr << "skp ngram" << ng << "\n";
+ continue;
+ }
+
+ //skip n-grams containing </s> in context
+ if (l>1 && ng.containsWord(dict->EoS(),l-1)) {
+ //cerr << "skp ngram" << ng << "\n";
+ continue;
+ }
+
+ //skip 1-grams containing <s>
+ if (l==1 && ng.containsWord(dict->BoS(),l)) {
+ //cerr << "skp ngram" << ng << "\n";
+ continue;
+ }
+
+ if (ng.freq==1) n1++;
+ else if (ng.freq==2) n2++;
+
+ }
+ //compute statistics of shiftbeta smoothing
+ if (beta[l]==-1) {
+ if (n1>0)
+ beta[l]=(double)n1/(double)(n1 + 2 * n2);
+ else {
+ // no singletons at this level: fall back to the maximum discount
+ cerr << "no singletons! \n";
+ beta[l]=1.0;
+ }
+ }
+ cerr << beta[l] << "\n";
+ }
+
+ return 1;
+};
+
+
+
+// Compute the discounted relative frequency fstar and the back-off weight
+// lambda for n-gram ng_ at order `size`, subtracting the level-dependent
+// discount beta[size] from each observed count. cv counts are removed for
+// cross-validation estimation, and corrections are applied when singleton
+// pruning is active. Always returns 1.
+int shiftbeta::discount(ngram ng_,int size,double& fstar,double& lambda, int cv)
+{
+
+ ngram ng(dict);
+ ng.trans(ng_);
+
+ if (size > 1) {
+
+ ngram history=ng;
+
+ if (ng.ckhisto(size) && get(history,size,size-1) && (history.freq>cv) &&
+
+ ((size < 3) || ((history.freq-cv) > prunethresh ))) {
+
+ // apply history pruning on trigrams only
+
+
+ if (get(ng,size,size) && (!prunesingletons() || ng.freq >1 || size<3)) {
+ cv=(cv>ng.freq)?ng.freq:cv;
+
+ if (ng.freq>cv) {
+
+ fstar=(double)((double)(ng.freq - cv) - beta[size])/(double)(history.freq-cv);
+
+ lambda=beta[size]*((double)history.succ/(double)(history.freq-cv));
+
+ if (size>=3 && prunesingletons()) // correction due to frequency pruning
+
+ lambda+=(1.0-beta[size]) * (double)succ1(history.link)/(double)(history.freq-cv);
+
+ // succ1(history.link) is not affected if ng.freq > cv
+
+ } else { // ng.freq == cv
+
+ fstar=0.0;
+
+ lambda=beta[size]*((double)(history.succ-1)/ //the successor has disappeared
+ (double)(history.freq-cv));
+
+ if (size>=3 && prunesingletons()) //take into account single event pruning
+ lambda+=(1.0-beta[size]) * (double)(succ1(history.link)-(cv==1 && ng.freq==1?1:0))
+ /(double)(history.freq-cv);
+ }
+ } else {
+ // n-gram unseen or pruned as a singleton: all mass goes to back-off
+ fstar=0.0;
+ lambda=beta[size]*(double)history.succ/(double)history.freq;
+
+ if (size>=3 && prunesingletons()) // correction due to frequency pruning
+ lambda+=(1.0-beta[size]) * (double)succ1(history.link)/(double)history.freq;
+
+ }
+
+ //cerr << "ngram :" << ng << "\n";
+
+ // fold the OOV-word probability into lambda
+ if (*ng.wordp(1)==dict->oovcode()) {
+ lambda+=fstar;
+ fstar=0.0;
+ } else {
+ *ng.wordp(1)=dict->oovcode();
+ if (get(ng,size,size) && (!prunesingletons() || ng.freq >1 || size<3))
+ lambda+=(double)((double)ng.freq - beta[size])/(double)(history.freq-cv);
+ }
+
+ } else {
+ // pruned or unseen history: back off entirely to the lower order
+ fstar=0;
+ lambda=1;
+ }
+ } else {
+ // order 1: plain unigram probability, no back-off mass
+ fstar=unigr(ng);
+ lambda=0.0;
+ }
+
+ return 1;
+}
+
+//
+//Improved Kneser-Ney language model (previously ModifiedShiftBeta)
+//
+
+// Build an improved Kneser-Ney smoothed LM; the three per-level discount
+// coefficients are estimated later in train().
+improvedkneserney::improvedkneserney(char* ngtfile,int depth,int prunefreq,TABLETYPE tt):
+  mdiadaptlm(ngtfile,depth,tt)
+{
+  cerr << "Creating LM with Improved Kneser-Ney smoothing\n";
+
+  prunethresh=prunefreq;
+  cerr << "PruneThresh: " << prunethresh << "\n";
+
+  // clear the three discount slots of row 1
+  for (int i=0; i<3; i++)
+    beta[1][i]=0.0;
+}
+
+
+// Estimate the improved Kneser-Ney discount coefficients: after training
+// unigrams and generating corrected counts and successor statistics, count
+// n-grams occurring exactly 1..4 times (n1..n4) at each level — skipping
+// OOV, sentence-boundary and <s> unigram cases — and derive the three
+// discounts beta[0..2][l] from the count-of-counts (Chen & Goodman style).
+// Falls back to the lower-order estimate or clamps to 0 when the
+// count-of-counts are degenerate; aborts if even n1/n2 are unusable.
+// Returns 1.
+int improvedkneserney::train()
+{
+
+ trainunigr();
+
+ gencorrcounts();
+ gensuccstat();
+
+ ngram ng(dict);
+ int n1,n2,n3,n4;
+ int unover3=0;
+
+ oovsum=0;
+
+ for (int l=1; l<=lmsize(); l++) {
+
+ cerr << "level " << l << "\n";
+
+ cerr << "computing statistics\n";
+
+ n1=0;
+ n2=0;
+ n3=0,n4=0;
+
+ scan(ng,INIT,l);
+
+ while(scan(ng,CONT,l)) {
+
+ //skip ngrams containing _OOV
+ if (l>1 && ng.containsWord(dict->OOV(),l)) {
+ continue;
+ }
+
+ //skip n-grams containing </s> in context
+ if (l>1 && ng.containsWord(dict->EoS(),l-1)) {
+ continue;
+ }
+
+ //skip 1-grams containing <s>
+ if (l==1 && ng.containsWord(dict->BoS(),l)) {
+ continue;
+ }
+
+ // use the corrected (Kneser-Ney) count at this level
+ ng.freq=mfreq(ng,l);
+
+ if (ng.freq==1) n1++;
+ else if (ng.freq==2) n2++;
+ else if (ng.freq==3) n3++;
+ else if (ng.freq==4) n4++;
+ if (l==1 && ng.freq >=3) unover3++;
+ }
+
+ if (l==1) {
+ cerr << " n1: " << n1 << " n2: " << n2 << " n3: " << n3 << " n4: " << n4 << " unover3: " << unover3 << "\n";
+ } else {
+ cerr << " n1: " << n1 << " n2: " << n2 << " n3: " << n3 << " n4: " << n4 << "\n";
+ }
+
+ if (n1 == 0 || n2 == 0 || n1 <= n2) {
+ std::stringstream ss_msg;
+ ss_msg << "Error: lower order count-of-counts cannot be estimated properly\n";
+ ss_msg << "Hint: use another smoothing method with this corpus.\n";
+ exit_error(IRSTLM_ERROR_DATA,ss_msg.str());
+ }
+
+ double Y=(double)n1/(double)(n1 + 2 * n2);
+ beta[0][l] = Y; //equivalent to 1 - 2 * Y * n2 / n1
+
+ if (n3 == 0 || n4 == 0 || n2 <= n3 || n3 <= n4 ){
+ cerr << "Warning: higher order count-of-counts cannot be estimated properly\n";
+ cerr << "Fixing this problem by resorting only on the lower order count-of-counts\n";
+
+ beta[1][l] = Y;
+ beta[2][l] = Y;
+ }
+ else{
+ beta[1][l] = 2 - 3 * Y * n3 / n2;
+ beta[2][l] = 3 - 4 * Y * n4 / n3;
+ }
+
+ if (beta[1][l] < 0){
+ cerr << "Warning: discount coefficient is negative \n";
+ cerr << "Fixing this problem by setting beta to 0 \n";
+ beta[1][l] = 0;
+
+ }
+
+
+ if (beta[2][l] < 0){
+ cerr << "Warning: discount coefficient is negative \n";
+ cerr << "Fixing this problem by setting beta to 0 \n";
+ beta[2][l] = 0;
+
+ }
+
+
+ // total discounted mass reserved for OOV at the unigram level
+ if (l==1)
+ oovsum=beta[0][l] * (double) n1 + beta[1][l] * (double)n2 + beta[2][l] * (double)unover3;
+
+ cerr << beta[0][l] << " " << beta[1][l] << " " << beta[2][l] << "\n";
+ }
+
+ return 1;
+};
+
+
+
+// Improved Kneser-Ney discounting for an n-gram of order `size`.
+// Outputs: fstar = discounted probability mass of ng_ given its history;
+// lambda = back-off/interpolation weight of that history.
+// `cv` is a held-out count subtracted for cross-validation (0 = none).
+// Always returns 1.
+int improvedkneserney::discount(ngram ng_,int size,double& fstar,double& lambda, int cv)
+{
+ ngram ng(dict);
+ ng.trans(ng_); // re-encode the incoming ngram w.r.t. this model's dictionary
+
+ //cerr << "size:" << size << " ng:|" << ng <<"|\n";
+
+ if (size > 1) {
+
+ ngram history=ng;
+
+ //singleton pruning only on real counts!!
+ if (ng.ckhisto(size) && get(history,size,size-1) && (history.freq > cv) &&
+ ((size < 3) || ((history.freq-cv) > prunethresh ))) { // no history pruning with corrected counts!
+
+ // successor statistics of the history: succ1/succ2 presumably count
+ // distinct successors seen once/twice; suc[2] is the remainder
+ // (seen three or more times) — TODO confirm against ngramtable
+ int suc[3];
+ suc[0]=succ1(history.link);
+ suc[1]=succ2(history.link);
+ suc[2]=history.succ-suc[0]-suc[1];
+
+
+ if (get(ng,size,size) &&
+ (!prunesingletons() || mfreq(ng,size)>1 || size<3) &&
+ (!prunetopsingletons() || mfreq(ng,size)>1 || size<maxlevel())) {
+
+ ng.freq=mfreq(ng,size);
+
+ cv=(cv>ng.freq)?ng.freq:cv; // never leave out more than the observed count
+
+ if (ng.freq>cv) {
+
+ // discount by count class: beta[0]/beta[1] for corrected counts 1/2,
+ // beta[2] for counts >= 3
+ double b=(ng.freq-cv>=3?beta[2][size]:beta[ng.freq-cv-1][size]);
+
+ fstar=(double)((double)(ng.freq - cv) - b)/(double)(history.freq-cv);
+
+ lambda=(beta[0][size] * suc[0] + beta[1][size] * suc[1] + beta[2][size] * suc[2])
+ /
+ (double)(history.freq-cv);
+
+ if ((size>=3 && prunesingletons()) ||
+ (size==maxlevel() && prunetopsingletons())) // correction due to frequency pruning
+
+ lambda+=(double)(suc[0] * (1-beta[0][size])) / (double)(history.freq-cv);
+
+ } else {
+ // ng.freq==cv
+
+ ng.freq>=3?suc[2]--:suc[ng.freq-1]--; //update successor stat
+
+ fstar=0.0;
+ lambda=(beta[0][size] * suc[0] + beta[1][size] * suc[1] + beta[2][size] * suc[2])
+ /
+ (double)(history.freq-cv);
+
+ if ((size>=3 && prunesingletons()) ||
+ (size==maxlevel() && prunetopsingletons())) // correction due to frequency pruning
+ lambda+=(double)(suc[0] * (1-beta[0][size])) / (double)(history.freq-cv);
+
+ ng.freq>=3?suc[2]++:suc[ng.freq-1]++; //resume successor stat
+ }
+ } else {
+ // full ngram absent (or pruned away): all mass goes to the back-off weight
+ fstar=0.0;
+ lambda=(beta[0][size] * suc[0] + beta[1][size] * suc[1] + beta[2][size] * suc[2])
+ /
+ (double)(history.freq-cv);
+
+ if ((size>=3 && prunesingletons()) ||
+ (size==maxlevel() && prunetopsingletons())) // correction due to frequency pruning
+ lambda+=(double)(suc[0] * (1-beta[0][size])) / (double)(history.freq-cv);
+
+ }
+
+ //cerr << "ngram :" << ng << "\n";
+
+
+ if (*ng.wordp(1)==dict->oovcode()) {
+ // predicted word is OOV: fold its discounted mass into lambda
+ lambda+=fstar;
+ fstar=0.0;
+ } else {
+ // add to lambda the mass of the <history,OOV> entry, if present
+ *ng.wordp(1)=dict->oovcode();
+ if (get(ng,size,size)) {
+ ng.freq=mfreq(ng,size);
+ if ((!prunesingletons() || ng.freq>1 || size<3) &&
+ (!prunetopsingletons() || ng.freq>1 || size<maxlevel())) {
+ double b=(ng.freq>=3?beta[2][size]:beta[ng.freq-1][size]);
+ lambda+=(double)(ng.freq - b)/(double)(history.freq-cv);
+ }
+ }
+ }
+ } else {
+ // history unseen or pruned: defer entirely to the lower-order model
+ fstar=0;
+ lambda=1;
+ }
+ } else { // unigram case, no cross-validation
+
+ fstar=unigrIKN(ng);
+ lambda=0.0;
+ }
+
+ return 1;
+}
+
+ // Unigram probability under Improved Kneser-Ney: the word's (possibly
+ // corrected) frequency over the appropriate total. A word missing from
+ // the table is a fatal data error.
+ double improvedkneserney::unigrIKN(ngram ng)
+ {
+ // denominator: boundary total for higher-order LMs, plain total otherwise
+ double denom=(double)((lmsize()>1)?btotfreq():totfreq());
+ if (!get(ng,1,1)) {
+ std::stringstream ss_msg;
+ ss_msg << "Missing probability for word: " << dict->decode(*ng.wordp(1));
+ exit_error(IRSTLM_ERROR_DATA,ss_msg.str());
+ }
+ return (double) mfreq(ng,1)/denom;
+ }
+
+ //
+ //Improved Shiftbeta language model (similar to Improved Kneser-Ney without corrected counts)
+ //
+
+ // Build an Improved ShiftBeta LM over the ngram table in `ngtfile`.
+ // `depth` is the LM order, `prunefreq` the history pruning threshold,
+ // `tt` the backing table type. Discounts are estimated later by train().
+ improvedshiftbeta::improvedshiftbeta(char* ngtfile,int depth,int prunefreq,TABLETYPE tt):
+ mdiadaptlm(ngtfile,depth,tt)
+ {
+ cerr << "Creating LM with Improved ShiftBeta smoothing\n";
+
+ prunethresh=prunefreq;
+ cerr << "PruneThresh: " << prunethresh << "\n";
+
+ // NOTE(review): only beta[1][0..2] is zeroed here, while train() fills
+ // beta[0..2][1..lmsize()]; confirm no other cell is read before train().
+ beta[1][0]=0.0;
+ beta[1][1]=0.0;
+ beta[1][2]=0.0;
+
+ };
+
+
+ // Estimate the three discount constants beta[0..2][l] for every ngram
+ // order l from the count-of-counts n1..n4 of the training table, after
+ // training the unigram model and generating successor statistics.
+ // Also accumulates `oovsum` from the level-1 statistics.
+ // Returns 1; exits if the counts do not permit a sound estimate.
+ int improvedshiftbeta::train()
+ {
+
+ trainunigr();
+
+ gensuccstat();
+
+ ngram ng(dict);
+ int n1,n2,n3,n4; // count-of-counts: #ngrams with frequency exactly 1,2,3,4
+ int unover3=0; // #unigrams with frequency >= 3
+
+ oovsum=0;
+
+ for (int l=1; l<=lmsize(); l++) {
+
+ cerr << "level " << l << "\n";
+
+ cerr << "computing statistics\n";
+
+ n1=0;
+ n2=0;
+ n3=0,n4=0;
+
+ scan(ng,INIT,l);
+
+ while(scan(ng,CONT,l)) {
+
+ //skip ngrams containing _OOV
+ if (l>1 && ng.containsWord(dict->OOV(),l)) {
+ continue;
+ }
+
+ //skip n-grams containing </s> in context
+ if (l>1 && ng.containsWord(dict->EoS(),l-1)) {
+ continue;
+ }
+
+ //skip 1-grams containing <s>
+ if (l==1 && ng.containsWord(dict->BoS(),l)) {
+ continue;
+ }
+
+ ng.freq=mfreq(ng,l);
+
+ if (ng.freq==1) n1++;
+ else if (ng.freq==2) n2++;
+ else if (ng.freq==3) n3++;
+ else if (ng.freq==4) n4++;
+ if (l==1 && ng.freq >=3) unover3++;
+ }
+
+ if (l==1) {
+ cerr << " n1: " << n1 << " n2: " << n2 << " n3: " << n3 << " n4: " << n4 << " unover3: " << unover3 << "\n";
+ } else {
+ cerr << " n1: " << n1 << " n2: " << n2 << " n3: " << n3 << " n4: " << n4 << "\n";
+ }
+
+ // a usable estimate requires n1 > n2 > 0
+ if (n1 == 0 || n2 == 0 || n1 <= n2) {
+ std::stringstream ss_msg;
+ ss_msg << "Error: lower order count-of-counts cannot be estimated properly\n";
+ ss_msg << "Hint: use another smoothing method with this corpus.\n";
+ exit_error(IRSTLM_ERROR_DATA,ss_msg.str());
+ }
+
+ // count-of-counts based discount estimates (Chen & Goodman style)
+ double Y=(double)n1/(double)(n1 + 2 * n2);
+ beta[0][l] = Y; //equivalent to 1 - 2 * Y * n2 / n1
+
+ if (n3 == 0 || n4 == 0 || n2 <= n3 || n3 <= n4 ){
+ cerr << "Warning: higher order count-of-counts cannot be estimated properly\n";
+ cerr << "Fixing this problem by resorting only on the lower order count-of-counts\n";
+
+ beta[1][l] = Y;
+ beta[2][l] = Y;
+ }
+ else{
+ beta[1][l] = 2 - 3 * Y * n3 / n2;
+ beta[2][l] = 3 - 4 * Y * n4 / n3;
+ }
+
+ // clamp negative discounts to zero
+ if (beta[1][l] < 0){
+ cerr << "Warning: discount coefficient is negative \n";
+ cerr << "Fixing this problem by setting beta to 0 \n";
+ beta[1][l] = 0;
+
+ }
+
+
+ if (beta[2][l] < 0){
+ cerr << "Warning: discount coefficient is negative \n";
+ cerr << "Fixing this problem by setting beta to 0 \n";
+ beta[2][l] = 0;
+
+ }
+
+
+ if (l==1)
+ oovsum=beta[0][l] * (double) n1 + beta[1][l] * (double)n2 + beta[2][l] * (double)unover3;
+
+ cerr << beta[0][l] << " " << beta[1][l] << " " << beta[2][l] << "\n";
+ }
+
+ return 1;
+ };
+
+
+
+ // Improved ShiftBeta discounting for an n-gram of order `size`.
+ // Same scheme as improvedkneserney::discount, but frequencies are raw
+ // (mfreq here returns ng.freq) and the unigram case uses unigr().
+ // Outputs fstar (discounted probability) and lambda (back-off weight);
+ // `cv` is a held-out count for cross-validation. Always returns 1.
+ int improvedshiftbeta::discount(ngram ng_,int size,double& fstar,double& lambda, int cv)
+ {
+ ngram ng(dict);
+ ng.trans(ng_); // re-encode w.r.t. this model's dictionary
+
+ //cerr << "size:" << size << " ng:|" << ng <<"|\n";
+
+ if (size > 1) {
+
+ ngram history=ng;
+
+ //singleton pruning only on real counts!!
+ if (ng.ckhisto(size) && get(history,size,size-1) && (history.freq > cv) &&
+ ((size < 3) || ((history.freq-cv) > prunethresh ))) { // no history pruning with corrected counts!
+
+ // successor statistics of the history: succ1/succ2 presumably count
+ // distinct successors seen once/twice; suc[2] is the remainder
+ int suc[3];
+ suc[0]=succ1(history.link);
+ suc[1]=succ2(history.link);
+ suc[2]=history.succ-suc[0]-suc[1];
+
+
+ if (get(ng,size,size) &&
+ (!prunesingletons() || mfreq(ng,size)>1 || size<3) &&
+ (!prunetopsingletons() || mfreq(ng,size)>1 || size<maxlevel())) {
+
+ ng.freq=mfreq(ng,size);
+
+ cv=(cv>ng.freq)?ng.freq:cv; // never leave out more than the observed count
+
+ if (ng.freq>cv) {
+
+ // discount by count class: beta[0]/beta[1] for counts 1/2, beta[2] for >= 3
+ double b=(ng.freq-cv>=3?beta[2][size]:beta[ng.freq-cv-1][size]);
+
+ fstar=(double)((double)(ng.freq - cv) - b)/(double)(history.freq-cv);
+
+ lambda=(beta[0][size] * suc[0] + beta[1][size] * suc[1] + beta[2][size] * suc[2])
+ /
+ (double)(history.freq-cv);
+
+ if ((size>=3 && prunesingletons()) ||
+ (size==maxlevel() && prunetopsingletons())) // correction due to frequency pruning
+
+ lambda+=(double)(suc[0] * (1-beta[0][size])) / (double)(history.freq-cv);
+
+ } else {
+ // ng.freq==cv
+
+ ng.freq>=3?suc[2]--:suc[ng.freq-1]--; //update successor stat
+
+ fstar=0.0;
+ lambda=(beta[0][size] * suc[0] + beta[1][size] * suc[1] + beta[2][size] * suc[2])
+ /
+ (double)(history.freq-cv);
+
+ if ((size>=3 && prunesingletons()) ||
+ (size==maxlevel() && prunetopsingletons())) // correction due to frequency pruning
+ lambda+=(double)(suc[0] * (1-beta[0][size])) / (double)(history.freq-cv);
+
+ ng.freq>=3?suc[2]++:suc[ng.freq-1]++; //resume successor stat
+ }
+ } else {
+ // full ngram absent (or pruned): all mass goes to the back-off weight
+ fstar=0.0;
+ lambda=(beta[0][size] * suc[0] + beta[1][size] * suc[1] + beta[2][size] * suc[2])
+ /
+ (double)(history.freq-cv);
+
+ if ((size>=3 && prunesingletons()) ||
+ (size==maxlevel() && prunetopsingletons())) // correction due to frequency pruning
+ lambda+=(double)(suc[0] * (1-beta[0][size])) / (double)(history.freq-cv);
+
+ }
+
+ //cerr << "ngram :" << ng << "\n";
+
+
+ if (*ng.wordp(1)==dict->oovcode()) {
+ // predicted word is OOV: fold its discounted mass into lambda
+ lambda+=fstar;
+ fstar=0.0;
+ } else {
+ // add to lambda the mass of the <history,OOV> entry, if present
+ *ng.wordp(1)=dict->oovcode();
+ if (get(ng,size,size)) {
+ ng.freq=mfreq(ng,size);
+ if ((!prunesingletons() || ng.freq>1 || size<3) &&
+ (!prunetopsingletons() || ng.freq>1 || size<maxlevel())) {
+ double b=(ng.freq>=3?beta[2][size]:beta[ng.freq-1][size]);
+ lambda+=(double)(ng.freq - b)/(double)(history.freq-cv);
+ }
+ }
+ }
+ } else {
+ // history unseen or pruned: defer entirely to the lower-order model
+ fstar=0;
+ lambda=1;
+ }
+ } else { // unigram case, no cross-validation
+ fstar=unigr(ng);
+ lambda=0;
+ }
+
+ return 1;
+ }
+
+//Symmetric Shiftbeta
+// Discounting for the symmetric shift-beta model: bigram probabilities
+// are computed on an order-independent (sorted) word pair.
+//   Pr(x/y)= max{(c([x,y])-beta)/(N Pr(y)),0} + lambda Pr(x)
+//   lambda=#bigrams/N
+// Outputs fstar and lambda; `cv` is unused. Always returns 1.
+int symshiftbeta::discount(ngram ng_,int size,double& fstar,double& lambda, int /* unused parameter: cv */)
+{
+ ngram ng(dict);
+ ng.trans(ng_);
+
+ //cerr << "size:" << size << " ng:|" << ng <<"|\n";
+
+ MY_ASSERT(size<=2); // only works with bigrams //
+
+ // (removed an unreachable "if (size == 3)" branch: it contradicted the
+ // assertion above and its body only declared an unused local copy)
+
+ if (size == 2) {
+
+ //compute unigram probability of denominator
+ ngram unig(dict,1);
+ *unig.wordp(1)=*ng.wordp(2);
+ double prunig=unigr(unig);
+
+ //create symmetric bigram: order the two word codes ascending
+ if (*ng.wordp(1) > *ng.wordp(2)) {
+ int tmp=*ng.wordp(1);
+ *ng.wordp(1)=*ng.wordp(2);
+ *ng.wordp(2)=tmp;
+ }
+
+ lambda=beta[2] * (double) entries(2)/(double)totfreq();
+
+ if (get(ng,2,2)) {
+ fstar=(double)((double)ng.freq - beta[2])/
+ (totfreq() * prunig);
+ } else {
+ fstar=0;
+ }
+ } else {
+ // unigram case
+ fstar=unigr(ng);
+ lambda=0.0;
+ }
+ return 1;
+}
+
+}//namespace irstlm
+
+
+/*
+main(int argc, char** argv){
+ dictionary d(argv[1]);
+
+ shiftbeta ilm(&d,argv[2],3);
+
+ ngramtable test(&d,argv[2],3);
+ ilm.train();
+ cerr << "PP " << ilm.test(test) << "\n";
+
+ ilm.savebin("newlm.lm",3);
+}
+
+*/
diff --git a/src/shiftlm.h b/src/shiftlm.h
new file mode 100644
index 0000000..51a51ab
--- /dev/null
+++ b/src/shiftlm.h
@@ -0,0 +1,108 @@
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+
+namespace irstlm {
+
+// Non linear Shift based interpolated LMs
+
+// Shift-based smoothing with a single fixed discount (see shiftone::train
+// and shiftone::discount for the estimation and use of `beta`).
+class shiftone: public mdiadaptlm
+{
+protected:
+ int prunethresh; // frequency threshold for pruning
+ double beta; // the single discount constant
+public:
+ shiftone(char* ngtfile,int depth=0,int prunefreq=0,TABLETYPE tt=SHIFTBETA_B);
+ int train();
+ int discount(ngram ng,int size,double& fstar,double& lambda,int cv=0);
+ ~shiftone() {}
+};
+
+
+// Shift-beta smoothing with one discount constant per ngram order,
+// stored in the dynamically allocated array `beta`.
+class shiftbeta: public mdiadaptlm
+{
+protected:
+ int prunethresh; // frequency threshold for pruning
+ double* beta; // per-order discount constants (owned; freed in dtor)
+
+public:
+ // beta=-1 requests estimation of the discounts at train() time
+ shiftbeta(char* ngtfile,int depth=0,int prunefreq=0,double beta=-1,TABLETYPE tt=SHIFTBETA_B);
+ int train();
+ int discount(ngram ng,int size,double& fstar,double& lambda,int cv=0);
+ ~shiftbeta() {
+ delete [] beta;
+ }
+
+};
+
+
+// Symmetric variant of shiftbeta: discount() treats a bigram as an
+// unordered word pair (see symshiftbeta::discount in shiftlm.cpp).
+class symshiftbeta: public shiftbeta
+{
+public:
+ symshiftbeta(char* ngtfile,int depth=0,int prunefreq=0,double beta=-1):
+ shiftbeta(ngtfile,depth,prunefreq,beta) {}
+ int discount(ngram ng,int size,double& fstar,double& lambda,int cv=0);
+};
+
+
+ // Improved Kneser-Ney smoothing: three discount constants per order
+ // (beta[count-class][order]), estimated from count-of-counts in train().
+ class improvedkneserney: public mdiadaptlm
+ {
+ protected:
+ int prunethresh; // frequency threshold for pruning
+ double beta[3][MAX_NGRAM]; // discounts for counts 1, 2, >=3 at each order
+ ngramtable* tb[MAX_NGRAM];
+
+ double oovsum; // OOV mass accumulated by train()
+
+ public:
+ improvedkneserney(char* ngtfile,int depth=0,int prunefreq=0,TABLETYPE tt=IMPROVEDKNESERNEY_B);
+ int train();
+ int discount(ngram ng,int size,double& fstar,double& lambda,int cv=0);
+
+ ~improvedkneserney() {}
+
+ // corrected (Kneser-Ney) frequency below the top level, raw at the top
+ int mfreq(ngram& ng,int l) {
+ return (l<lmsize()?getfreq(ng.link,ng.pinfo,1):ng.freq);
+ }
+
+ double unigrIKN(ngram ng);
+ inline double unigr(ngram ng){ return unigrIKN(ng); };
+ };
+
+// Improved ShiftBeta smoothing: like improvedkneserney but without
+// corrected counts (mfreq simply returns the raw frequency).
+class improvedshiftbeta: public mdiadaptlm
+{
+protected:
+ int prunethresh; // frequency threshold for pruning
+ double beta[3][MAX_NGRAM]; // discounts for counts 1, 2, >=3 at each order
+ ngramtable* tb[MAX_NGRAM];
+
+ double oovsum; // OOV mass accumulated by train()
+
+public:
+ improvedshiftbeta(char* ngtfile,int depth=0,int prunefreq=0,TABLETYPE tt=IMPROVEDSHIFTBETA_B);
+ int train();
+ int discount(ngram ng,int size,double& fstar,double& lambda,int cv=0);
+
+ ~improvedshiftbeta() {}
+
+ // raw frequency: no count correction in this model
+ inline int mfreq(ngram& ng,int /*NOT_USED l*/) { return ng.freq; }
+
+};
+
+}//namespace irstlm
diff --git a/src/stream-tlm.cpp b/src/stream-tlm.cpp
new file mode 100644
index 0000000..f38d177
--- /dev/null
+++ b/src/stream-tlm.cpp
@@ -0,0 +1,575 @@
+/******************************************************************************
+IrstLM: IRST Language Model Toolkit
+Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+This library is free software; you can redistribute it and/or
+modify it under the terms of the GNU Lesser General Public
+License as published by the Free Software Foundation; either
+version 2.1 of the License, or (at your option) any later version.
+
+This library is distributed in the hope that it will be useful,
+but WITHOUT ANY WARRANTY; without even the implied warranty of
+MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+Lesser General Public License for more details.
+
+You should have received a copy of the GNU Lesser General Public
+License along with this library; if not, write to the Free Software
+Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+******************************************************************************/
+
+using namespace std;
+
+#include <cmath>
+#include <math.h>
+#include "mfstream.h"
+#include <fstream>
+#include <stdio.h>
+#include <iostream>
+#include "mempool.h"
+#include "htable.h"
+#include "dictionary.h"
+#include "n_gram.h"
+#include "mempool.h"
+#include "ngramcache.h"
+#include "ngramtable.h"
+#include "interplm.h"
+#include "normcache.h"
+#include "mdiadapt.h"
+#include "shiftlm.h"
+#include "linearlm.h"
+#include "mixture.h"
+#include "cmd.h"
+#include "lmtable.h"
+
+
+#define YES 1
+#define NO 0
+
+
+#define NGRAM 1
+#define SEQUENCE 2
+#define ADAPT 3
+#define TURN 4
+#define TEXT 5
+
+
+#define END_ENUM { (char*)0, 0 }
+
+// Accepted spellings for boolean command-line values.
+static Enum_T BooleanEnum [] = {
+ { "Yes", YES },
+ { "No", NO},
+ { "yes", YES },
+ { "no", NO},
+ { "y", YES },
+ { "n", NO},
+ END_ENUM
+};
+
+// Mapping from command-line names (long and short) to LM smoothing types.
+static Enum_T LmTypeEnum [] = {
+ { "ImprovedKneserNey", IMPROVED_KNESER_NEY },
+ { "ikn", IMPROVED_KNESER_NEY },
+ { "KneserNey", KNESER_NEY },
+ { "kn", KNESER_NEY },
+ { "ModifiedShiftBeta", MOD_SHIFT_BETA },
+ { "msb", MOD_SHIFT_BETA },
+ { "ImprovedShiftBeta", IMPROVED_SHIFT_BETA },
+ { "isb", IMPROVED_SHIFT_BETA },
+ { "InterpShiftBeta", SHIFT_BETA },
+ { "ShiftBeta", SHIFT_BETA },
+ { "sb", SHIFT_BETA },
+ { "InterpShiftOne", SHIFT_ONE },
+ { "ShiftOne", SHIFT_ONE },
+ { "s1", SHIFT_ONE },
+ { "LinearWittenBell", LINEAR_WB },
+ { "wb", LINEAR_WB },
+ { "LinearGoodTuring", LINEAR_GT },
+ { "Mixture", MIXTURE },
+ { "mix", MIXTURE },
+ END_ENUM
+};
+
+
+#define RESET 1
+#define SAVE 2
+#define LOAD 3
+#define INIT 4
+#define STOP 5
+
+#define BIN 11
+#define ARPA 12
+#define ASR 13
+#define TXT 14
+#define NGT 15
+
+
+int init(mdiadaptlm** lm, int lmtype, char *trainfile, int size, int prunefreq, double beta, int backoff, int dub, double oovrate, int mcl);
+int deinit(mdiadaptlm** lm);
+
+// Streaming LM trainer: reads whitespace-separated tokens from stdin and
+// feeds them into an incrementally built ngram table. The special token
+// "@CMD@" introduces a directive (RESET / INIT / SAVEBIN / SAVEARPA /
+// SAVEASR / SAVENGT / LOADNGT / LOADTXT / STOP) that saves, reloads,
+// resets or terminates the model being built.
+int main(int argc, char **argv)
+{
+
+ char *dictfile=NULL;
+ char *trainfile=NULL;
+
+ char *BINfile=NULL;
+ char *ARPAfile=NULL;
+ char *ASRfile=NULL;
+
+ int backoff=0; //back-off or interpolation
+ int lmtype=0;
+ int dub=0; //dictionary upper bound
+ int size=0; //lm size
+
+ int statistics=0;
+
+ int prunefreq=NO;
+ int prunesingletons=YES;
+ int prunetopsingletons=NO;
+
+ double beta=-1;
+
+ int compsize=NO;
+ int checkpr=NO;
+ double oovrate=0;
+ int max_caching_level=0;
+
+ char *outpr=NULL;
+
+ int memmap = 0; //write binary format with/without memory map, default is 0
+
+ // command-line option table: long name / short name pairs bound to the
+ // variables above
+ DeclareParams(
+
+ "Back-off",CMDENUMTYPE, &backoff, BooleanEnum,
+ "bo",CMDENUMTYPE, &backoff, BooleanEnum,
+
+ "Dictionary", CMDSTRINGTYPE, &dictfile,
+ "d", CMDSTRINGTYPE, &dictfile,
+
+ "DictionaryUpperBound", CMDINTTYPE, &dub,
+ "dub", CMDINTTYPE, &dub,
+
+ "NgramSize", CMDSUBRANGETYPE, &size, 1 , MAX_NGRAM,
+ "n", CMDSUBRANGETYPE, &size, 1 , MAX_NGRAM,
+
+ "Ngram", CMDSTRINGTYPE, &trainfile,
+ "TrainOn", CMDSTRINGTYPE, &trainfile,
+ "tr", CMDSTRINGTYPE, &trainfile,
+
+ "oASR", CMDSTRINGTYPE, &ASRfile,
+ "oasr", CMDSTRINGTYPE, &ASRfile,
+
+ "o", CMDSTRINGTYPE, &ARPAfile,
+ "oARPA", CMDSTRINGTYPE, &ARPAfile,
+ "oarpa", CMDSTRINGTYPE, &ARPAfile,
+
+ "oBIN", CMDSTRINGTYPE, &BINfile,
+ "obin", CMDSTRINGTYPE, &BINfile,
+
+ "LanguageModelType",CMDENUMTYPE, &lmtype, LmTypeEnum,
+ "lm",CMDENUMTYPE, &lmtype, LmTypeEnum,
+
+ "Statistics",CMDSUBRANGETYPE, &statistics, 1 , 3,
+ "s",CMDSUBRANGETYPE, &statistics, 1 , 3,
+
+ "PruneThresh",CMDSUBRANGETYPE, &prunefreq, 1 , 1000,
+ "p",CMDSUBRANGETYPE, &prunefreq, 1 , 1000,
+
+ "PruneSingletons",CMDENUMTYPE, &prunesingletons, BooleanEnum,
+ "ps",CMDENUMTYPE, &prunesingletons, BooleanEnum,
+
+ "PruneTopSingletons",CMDENUMTYPE, &prunetopsingletons, BooleanEnum,
+ "pts",CMDENUMTYPE, &prunetopsingletons, BooleanEnum,
+
+ "ComputeLMSize",CMDENUMTYPE, &compsize, BooleanEnum,
+ "sz",CMDENUMTYPE, &compsize, BooleanEnum,
+
+ "MaximumCachingLevel", CMDINTTYPE , &max_caching_level,
+ "mcl", CMDINTTYPE, &max_caching_level,
+
+ "MemoryMap", CMDENUMTYPE, &memmap, BooleanEnum,
+ "memmap", CMDENUMTYPE, &memmap, BooleanEnum,
+ "mm", CMDENUMTYPE, &memmap, BooleanEnum,
+
+ "CheckProb",CMDENUMTYPE, &checkpr, BooleanEnum,
+ "cp",CMDENUMTYPE, &checkpr, BooleanEnum,
+
+ "OutProb",CMDSTRINGTYPE, &outpr,
+ "op",CMDSTRINGTYPE, &outpr,
+
+ "SetOovRate", CMDDOUBLETYPE, &oovrate,
+ "or", CMDDOUBLETYPE, &oovrate,
+
+ "Beta", CMDDOUBLETYPE, &beta,
+ "beta", CMDDOUBLETYPE, &beta,
+
+ (char *)NULL
+ );
+
+ GetParams(&argc, &argv, (char*) NULL);
+
+ if (!lmtype) {
+ cerr <<"Missing parameters\n";
+ exit(1);
+ }
+
+
+ cerr <<"LM size: " << size << "\n";
+
+
+ char header[BUFSIZ];
+ char filename[BUFSIZ];
+ int cmdcounter=0;
+ mdiadaptlm *lm=NULL;
+
+
+ int cmdtype=INIT;
+ int filetype=0;
+ int BoSfreq=0;
+
+ init(&lm, lmtype, trainfile, size, prunefreq, beta, backoff, dub, oovrate, max_caching_level);
+
+ ngram ng(lm->dict), ng2(lm->dict);
+
+ cerr << "filling the initial n-grams with BoS\n";
+ for (int i=1; i<lm->maxlevel(); i++) {
+ ng.pushw(lm->dict->BoS());
+ ng.freq=1;
+ }
+
+ mfstream inp("/dev/stdin",ios::in );
+ int c=0;
+
+ // main loop: each token is either the @CMD@ marker or a word to add
+ while (inp >> header) {
+
+ if (strncmp(header,"@CMD@",5)==0) {
+ cmdcounter++;
+ inp >> header;
+
+ cerr << "Read |@CMD@| |" << header << "|";
+
+ // decode the directive name into a (cmdtype, filetype) pair
+ cmdtype=INIT;
+ filetype=BIN;
+ if (strncmp(header,"RESET",5)==0) cmdtype=RESET;
+ else if (strncmp(header,"INIT",4)==0) cmdtype=INIT;
+ else if (strncmp(header,"SAVEBIN",7)==0) {
+ cmdtype=SAVE;
+ filetype=BIN;
+ } else if (strncmp(header,"SAVEARPA",8)==0) {
+ cmdtype=SAVE;
+ filetype=ARPA;
+ } else if (strncmp(header,"SAVEASR",7)==0) {
+ cmdtype=SAVE;
+ filetype=ASR;
+ } else if (strncmp(header,"SAVENGT",7)==0) {
+ cmdtype=SAVE;
+ filetype=NGT;
+ } else if (strncmp(header,"LOADNGT",7)==0) {
+ cmdtype=LOAD;
+ filetype=NGT;
+ } else if (strncmp(header,"LOADTXT",7)==0) {
+ cmdtype=LOAD;
+ filetype=TXT;
+ } else if (strncmp(header,"STOP",4)==0) cmdtype=STOP;
+ else {
+ cerr << "CMD " << header << " is unknown\n";
+ exit(1);
+ }
+
+ char** lastwords;
+ char *isym;
+ switch (cmdtype) {
+
+ case STOP:
+ // NOTE(review): STOP terminates with status 1 — confirm a nonzero
+ // exit code is intended here
+ cerr << "\n";
+ exit(1);
+ break;
+
+ case SAVE:
+
+ inp >> filename; //storing the output filename
+ cerr << " |" << filename << "|\n";
+
+ //save actual ngramtable
+ char tmpngtfile[BUFSIZ];
+ sprintf(tmpngtfile,"%s.ngt%d",filename,cmdcounter);
+ // NOTE(review): the message says "binary" but savetxt is called — verify
+ cerr << "saving temporary ngramtable (binary)..." << tmpngtfile << "\n";
+ ((ngramtable*) lm)->ngtype("ngram");
+ ((ngramtable*) lm)->savetxt(tmpngtfile,size);
+
+ //get the actual frequency of BoS symbol, because the constructor of LM will reset to 1;
+ BoSfreq=lm->dict->freq(lm->dict->encode(lm->dict->BoS()));
+
+ lm->train();
+
+ lm->prunesingletons(prunesingletons==YES);
+ lm->prunetopsingletons(prunetopsingletons==YES);
+
+ if (prunetopsingletons==YES) //keep most specific
+ lm->prunesingletons(NO);
+
+
+ switch (filetype) {
+
+ case BIN:
+ cerr << "saving lm (binary) ... " << filename << "\n";
+ lm->saveBIN(filename,backoff,dictfile,memmap);
+ cerr << "\n";
+ break;
+
+ case ARPA:
+ cerr << "save lm (ARPA)... " << filename << "\n";
+ lm->saveARPA(filename,backoff,dictfile);
+ cerr << "\n";
+ break;
+
+ case ASR:
+ cerr << "save lm (ASR)... " << filename << "\n";
+ lm->saveASR(filename,backoff,dictfile);
+ cerr << "\n";
+ break;
+
+ case NGT:
+ // copy the temporary table file to the requested destination
+ cerr << "save the ngramtable on ... " << filename << "\n";
+ {
+ ifstream ifs(tmpngtfile, ios::binary);
+ std::ofstream ofs(filename, std::ios::binary);
+ ofs << ifs.rdbuf();
+ }
+ cerr << "\n";
+ break;
+
+ default:
+ cerr << "Saving type is unknown\n";
+ exit(1);
+ };
+
+ //store last words up to the LM order (filling with BoS if needed)
+ ng.size=(ng.size>lm->maxlevel())?lm->maxlevel():ng.size;
+ lastwords = new char*[lm->maxlevel()];
+
+ for (int i=1; i<lm->maxlevel(); i++) {
+ lastwords[i] = new char[BUFSIZ];
+ if (i<=ng.size)
+ strcpy(lastwords[i],lm->dict->decode(*ng.wordp(i)));
+ else
+ strcpy(lastwords[i],lm->dict->BoS());
+ }
+
+ // rebuild the model from the just-saved table, then delete the temp file
+ deinit(&lm);
+
+ init(&lm, lmtype, tmpngtfile, size, prunefreq, beta, backoff, dub, oovrate, max_caching_level);
+ if (remove(tmpngtfile) != 0)
+ cerr << "Error deleting file " << tmpngtfile << endl;
+ else
+ cerr << "File " << tmpngtfile << " successfully deleted" << endl;
+
+ //re-set the dictionaries of the working ngrams and re-encode the actual ngram
+ ng.dict=ng2.dict=lm->dict;
+ ng.size=lm->maxlevel();
+
+ //restore the last words re-encoded wrt to the new dictionary
+ for (int i=1; i<lm->maxlevel(); i++) {
+ *ng.wordp(i)=lm->dict->encode(lastwords[i]);
+ delete []lastwords[i];
+ }
+ delete []lastwords;
+
+
+ //re-set the actual frequency of BoS symbol, because the constructor of LM deleted it;
+ lm->dict->freq(lm->dict->encode(lm->dict->BoS()), BoSfreq);
+ break;
+
+
+ case RESET: //restart from scratch
+ deinit(&lm);
+
+ init(&lm, lmtype, NULL, size, prunefreq, beta, backoff, dub, oovrate, max_caching_level);
+
+ ng.dict=ng2.dict=lm->dict;
+ cerr << "filling the initial n-grams with BoS\n";
+ for (int i=1; i<lm->maxlevel(); i++) {
+ ng.pushw(lm->dict->BoS());
+ ng.freq=1;
+ }
+ break;
+
+
+ case INIT:
+ cerr << "CMD " << header << " not yet implemented\n";
+ exit(1);
+ break;
+
+ case LOAD:
+ inp >> filename; //storing the input filename
+ cerr << " |" << filename << "|\n";
+
+
+ isym=new char[BUFSIZ];
+ strcpy(isym,lm->dict->EoS());
+ ngramtable* ngt;
+
+ switch (filetype) {
+
+ case NGT:
+ cerr << "loading an ngramtable..." << filename << "\n";
+ ngt = new ngramtable(filename,size,isym,NULL,NULL);
+ ((ngramtable*) lm)->augment(ngt);
+ cerr << "\n";
+ break;
+
+ case TXT:
+ cerr << "loading from text..." << filename << "\n";
+ ngt= new ngramtable(filename,size,isym,NULL,NULL);
+ ((ngramtable*) lm)->augment(ngt);
+ cerr << "\n";
+ break;
+
+ default:
+ cerr << "This file type is unknown\n";
+ exit(1);
+ };
+
+ break;
+
+ default:
+ cerr << "CMD " << header << " is unknown\n";
+ exit(1);
+ };
+ } else {
+ // ordinary word: extend the running ngram and insert it in the table
+ ng.pushw(header);
+
+ // CHECK: is this trans() really needed?
+ ng2.trans(ng); //reencode with new dictionary
+
+ lm->check_dictsize_bound();
+
+ //CHECK: is ng.size correct here? shouldn't it be ng2.size?
+ if (ng.size) lm->dict->incfreq(*ng2.wordp(1),1);
+ //CHECK: what about filtering dictionary???
+ /*
+ if (filterdict){
+ int code=filterdict->encode(dict->decode(*ng2.wordp(maxlev)));
+ if (code!=filterdict->oovcode()) put(ng2);
+ }
+ else put(ng2);
+ */
+
+ lm->put(ng2);
+
+ if (!(++c % 1000000)) cerr << "."; // progress dot every 1M tokens
+ }
+ }
+
+ if (statistics) {
+ cerr << "TLM: lm stat ...";
+ lm->lmstat(statistics);
+ cerr << "\n";
+ }
+
+ cerr << "TLM: deleting lm ...";
+ //delete lm;
+ cerr << "\n";
+
+ exit(0);
+}
+
+// Allocate and initialize the language model *lm of type `lmtype`,
+// optionally building it from `trainfile` (NULL creates an empty model).
+// size = ngram order, prunefreq = pruning threshold, beta = discount
+// constant (-1 = estimate at training time), backoff selects back-off vs
+// interpolation, dub = dictionary upper bound, oovrate = OOV rate,
+// mcl = maximum caching level. Returns 1; exits on invalid parameters.
+int init(mdiadaptlm** lm, int lmtype, char *trainfile, int size, int prunefreq, double beta, int backoff, int dub, double oovrate, int mcl)
+{
+
+ cerr << "initializing lm... \n";
+ if (trainfile) cerr << "creating lm from " << trainfile << "\n";
+ else cerr << "creating an empty lm\n";
+ switch (lmtype) {
+
+ case SHIFT_BETA:
+ if (beta==-1 || (beta<1.0 && beta>0))
+ *lm=new shiftbeta(trainfile,size,prunefreq,beta,(backoff?SHIFTBETA_B:SHIFTBETA_I));
+ else {
+ cerr << "ShiftBeta: beta must be >0 and <1\n";
+ exit(1);
+ }
+ break;
+
+ case KNESER_NEY:
+ if (size>1){
+ if (beta==-1 || (beta<1.0 && beta>0)){
+// lm=new kneserney(trainfile,size,prunefreq,beta,(backoff?KNESERNEY_B:KNESERNEY_I));
+ } else {
+ exit_error(IRSTLM_ERROR_DATA,"ShiftBeta: beta must be >0 and <1");
+ }
+ } else {
+ exit_error(IRSTLM_ERROR_DATA,"Kneser-Ney requires size >1");
+ }
+ break;
+
+ case MOD_SHIFT_BETA:
+ cerr << "ModifiedShiftBeta (msb) is the old name for ImprovedKneserNey (ikn); this name is not supported anymore, but it is mapped into ImprovedKneserNey for back-compatibility";
+ // intentional fall through: msb is an alias of ikn
+ case IMPROVED_KNESER_NEY:
+ if (size>1){
+ // fix: assign through *lm — the original assigned to the
+ // mdiadaptlm** parameter itself, which cannot compile
+ *lm=new improvedkneserney(trainfile,size,prunefreq,(backoff?IMPROVEDKNESERNEY_B:IMPROVEDKNESERNEY_I));
+ } else {
+ exit_error(IRSTLM_ERROR_DATA,"Improved Kneser-Ney requires size >1");
+ }
+ break;
+
+ case IMPROVED_SHIFT_BETA:
+ // fix: assign through *lm, consistently with the other cases
+ *lm=new improvedshiftbeta(trainfile,size,prunefreq,(backoff?IMPROVEDSHIFTBETA_B:IMPROVEDSHIFTBETA_I));
+ break;
+
+ case SHIFT_ONE:
+ *lm=new shiftone(trainfile,size,prunefreq,(backoff?SIMPLE_B:SIMPLE_I));
+ break;
+
+ case LINEAR_WB:
+ *lm=new linearwb(trainfile,size,prunefreq,(backoff?MSHIFTBETA_B:MSHIFTBETA_I));
+ break;
+
+ case LINEAR_GT:
+ cerr << "This LM is no more supported\n";
+ break;
+
+ case MIXTURE:
+ cerr << "not implemented yet\n";
+ break;
+
+ default:
+ cerr << "not implemented yet\n";
+ exit(1);
+ };
+
+ if (dub) (*lm)->dub(dub);
+ (*lm)->create_caches(mcl);
+
+ cerr << "eventually generate OOV code\n";
+ (*lm)->dict->genoovcode();
+
+ if (oovrate) (*lm)->dict->setoovrate(oovrate);
+
+ (*lm)->dict->incflag(1);
+
+ if (!trainfile) {
+ // seed an empty model with BoS-filled dummy ngrams so the table is consistent
+ cerr << "adding the initial dummy n-grams to make table consistent\n";
+
+ ngram dummyng((*lm)->dict);
+ cerr << "preparing initial dummy n-grams\n";
+ for (int i=1; i<(*lm)->maxlevel(); i++) {
+ dummyng.pushw((*lm)->dict->BoS());
+ dummyng.freq=1;
+ }
+ cerr << "inside init: dict: " << (*lm)->dict << " dictsize: " << (*lm)->dict->size() << "\n";
+ cerr << "dummyng: |" << dummyng << "\n";
+ (*lm)->put(dummyng);
+ cerr << "inside init: dict: " << (*lm)->dict << " dictsize: " << (*lm)->dict->size() << "\n";
+
+ }
+
+ cerr << "lm initialized \n";
+ return 1;
+}
+
+// Destroy the model allocated by init(). The pointer *lm is left dangling;
+// callers re-run init() before using it again. Returns 1.
+int deinit(mdiadaptlm** lm)
+{
+ delete *lm;
+ return 1;
+}
diff --git a/src/thpool.c b/src/thpool.c
new file mode 100644
index 0000000..0ecf093
--- /dev/null
+++ b/src/thpool.c
@@ -0,0 +1,551 @@
+/* ********************************
+ * Author: Johan Hanssen Seferidis
+ * License: MIT
+ * Description: Library providing a threading pool where you can add
+ * work. For usage, check the thpool.h file or README.md
+ *
+ *//** @file thpool.h *//*
+ *
+ ********************************/
+
+
+#include <unistd.h>
+#include <signal.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <pthread.h>
+#include <errno.h>
+#include <time.h>
+
+#include "thpool.h"
+
+#ifdef THPOOL_DEBUG
+#define THPOOL_DEBUG 1
+#else
+#define THPOOL_DEBUG 0
+#endif
+
+#define MAX_NANOSEC 999999999
+#define CEIL(X) ((X-(int)(X)) > 0 ? (int)(X+1) : (int)(X))
+
+static volatile int threads_keepalive;
+static volatile int threads_on_hold;
+
+
+
+
+
+/* ========================== STRUCTURES ============================ */
+
+
+/* Binary semaphore */
/* Binary semaphore built from a mutex and a condition variable */
typedef struct bsem {
	pthread_mutex_t mutex;
	pthread_cond_t  cond;
	int v;                          /* semaphore value: 0 or 1 only */
} bsem;


/* One unit of work waiting to be executed */
typedef struct job{
	struct job* prev;               /* next-older job in the queue        */
	void* (*function)(void* arg);   /* work routine to invoke             */
	void* arg;                      /* argument handed to the routine     */
} job;


/* FIFO of pending jobs, guarded by rwmutex */
typedef struct jobqueue{
	pthread_mutex_t rwmutex;        /* serialises all queue access        */
	job  *front;                    /* oldest job (next to be pulled)     */
	job  *rear;                     /* newest job (last pushed)           */
	bsem *has_jobs;                 /* signalled while the queue is non-empty */
	int   len;                      /* current number of queued jobs      */
} jobqueue;


/* One worker thread of the pool */
typedef struct thread{
	int id;                         /* friendly id (0..num_threads-1)     */
	pthread_t pthread;              /* underlying pthread handle          */
	struct thpool_* thpool_p;       /* back-pointer to the owning pool    */
} thread;


/* The thread pool itself */
typedef struct thpool_{
	thread** threads;               /* array of worker-thread records     */
	volatile int num_threads_alive;    /* workers currently registered    */
	volatile int num_threads_working;  /* workers currently running a job */
	pthread_mutex_t thcount_lock;   /* protects the two counters above    */
	jobqueue* jobqueue_p;           /* the pool's job queue               */
} thpool_;
+
+
+
+
+
+/* ========================== PROTOTYPES ============================ */
+
+
+static void thread_init(thpool_* thpool_p, struct thread** thread_p, int id);
+static void* thread_do(struct thread* thread_p);
+static void thread_hold();
+static void thread_destroy(struct thread* thread_p);
+
+static int jobqueue_init(thpool_* thpool_p);
+static void jobqueue_clear(thpool_* thpool_p);
+static void jobqueue_push(thpool_* thpool_p, struct job* newjob_p);
+static struct job* jobqueue_pull(thpool_* thpool_p);
+static void jobqueue_destroy(thpool_* thpool_p);
+
+static void bsem_init(struct bsem *bsem_p, int value);
+static void bsem_reset(struct bsem *bsem_p);
+static void bsem_post(struct bsem *bsem_p);
+static void bsem_post_all(struct bsem *bsem_p);
+static void bsem_wait(struct bsem *bsem_p);
+
+
+
+
+
+/* ========================== THREADPOOL ============================ */
+
+
+/* Initialise thread pool */
+struct thpool_* thpool_init(int num_threads){
+
+ threads_on_hold = 0;
+ threads_keepalive = 1;
+
+ if ( num_threads < 0){
+ num_threads = 0;
+ }
+
+ /* Make new thread pool */
+ thpool_* thpool_p=NULL;
+ thpool_p = (struct thpool_*)calloc(1,sizeof(struct thpool_));
+ if (thpool_p==NULL){
+ fprintf(stderr, "thpool_init(): Could not allocate memory for thread pool\n");
+ exit(1);
+ }
+ pthread_mutex_init(&(thpool_p->thcount_lock), NULL);
+ thpool_p->num_threads_alive = 0;
+ thpool_p->num_threads_working = 0;
+
+ /* Initialise the job queue */
+ if (jobqueue_init(thpool_p)==-1){
+ fprintf(stderr, "thpool_init(): Could not allocate memory for job queue\n");
+ exit(1);
+ }
+
+ /* Make threads in pool */
+ thpool_p->threads = (struct thread**)calloc(num_threads,sizeof(struct thread));
+ if (thpool_p->threads==NULL){
+ fprintf(stderr, "thpool_init(): Could not allocate memory for threads\n");
+ exit(1);
+ }
+
+ /* Thread init */
+ int n;
+ for (n=0; n<num_threads; n++){
+ thread_init(thpool_p, &thpool_p->threads[n], n);
+ if (THPOOL_DEBUG)
+ printf("THPOOL_DEBUG: Created thread %d in pool \n", n);
+ }
+
+ /* Wait for threads to initialize */
+ while (thpool_p->num_threads_alive != num_threads) {}
+
+ return thpool_p;
+}
+
+
+/* Add work to the thread pool */
+int thpool_add_work(thpool_* thpool_p, void *(*function_p)(void*), void* arg_p){
+ job* newjob=NULL;
+
+ newjob=(struct job*)calloc(1,sizeof(struct job));
+ if (newjob==NULL){
+ fprintf(stderr, "thpool_add_work(): Could not allocate memory for new job\n");
+ return -1;
+ }
+
+ /* add function and argument */
+ newjob->function=function_p;
+ newjob->arg=arg_p;
+
+ /* add job to queue */
+ pthread_mutex_lock(&thpool_p->jobqueue_p->rwmutex);
+ jobqueue_push(thpool_p, newjob);
+ pthread_mutex_unlock(&thpool_p->jobqueue_p->rwmutex);
+
+ return 0;
+}
+
+
+/* Wait until all jobs have finished */
+void thpool_wait(thpool_* thpool_p){
+
+ /* Continuous polling */
+ double timeout = 1.0;
+ time_t start, end;
+ double tpassed = 0.0;
+ time (&start);
+ while (tpassed < timeout &&
+ (thpool_p->jobqueue_p->len || thpool_p->num_threads_working))
+ {
+ time (&end);
+ tpassed = difftime(end,start);
+ }
+
+ /* Exponential polling */
+ long init_nano = 1; /* MUST be above 0 */
+ long new_nano;
+ double multiplier = 1.01;
+ int max_secs = 20;
+
+ struct timespec polling_interval;
+ polling_interval.tv_sec = 0;
+ polling_interval.tv_nsec = init_nano;
+
+ while (thpool_p->jobqueue_p->len || thpool_p->num_threads_working)
+ {
+ nanosleep(&polling_interval, NULL);
+ if ( polling_interval.tv_sec < max_secs ){
+ new_nano = CEIL(polling_interval.tv_nsec * multiplier);
+ polling_interval.tv_nsec = new_nano % MAX_NANOSEC;
+ if ( new_nano > MAX_NANOSEC ) {
+ polling_interval.tv_sec ++;
+ }
+ }
+ else break;
+ }
+
+ /* Fall back to max polling */
+ while (thpool_p->jobqueue_p->len || thpool_p->num_threads_working){
+ sleep(max_secs);
+ }
+}
+
+
+/* Destroy the threadpool */
+void thpool_destroy(thpool_* thpool_p){
+
+ volatile int threads_total = thpool_p->num_threads_alive;
+
+ /* End each thread 's infinite loop */
+ threads_keepalive = 0;
+
+ /* Give one second to kill idle threads */
+ double TIMEOUT = 1.0;
+ time_t start, end;
+ double tpassed = 0.0;
+ time (&start);
+ while (tpassed < TIMEOUT && thpool_p->num_threads_alive){
+ bsem_post_all(thpool_p->jobqueue_p->has_jobs);
+ time (&end);
+ tpassed = difftime(end,start);
+ }
+
+ /* Poll remaining threads */
+ while (thpool_p->num_threads_alive){
+ bsem_post_all(thpool_p->jobqueue_p->has_jobs);
+ sleep(1);
+ }
+
+ /* Job queue cleanup */
+ jobqueue_destroy(thpool_p);
+ free(thpool_p->jobqueue_p);
+
+ /* Deallocs */
+ int n;
+ for (n=0; n < threads_total; n++){
+ thread_destroy(thpool_p->threads[n]);
+ }
+ free(thpool_p->threads);
+ free(thpool_p);
+}
+
+
+/* Pause all threads in threadpool */
+void thpool_pause(thpool_* thpool_p) {
+ int n;
+ for (n=0; n < thpool_p->num_threads_alive; n++){
+ pthread_kill(thpool_p->threads[n]->pthread, SIGUSR1);
+ }
+}
+
+
+/* Resume all threads in threadpool */
+void thpool_resume(thpool_* thpool_p) {
+ threads_on_hold = 0;
+}
+
+
+
+
+
+/* ============================ THREAD ============================== */
+
+
+/* Initialize a thread in the thread pool
+ *
+ * @param thread address to the pointer of the thread to be created
+ * @param id id to be given to the thread
+ *
+ */
+static void thread_init (thpool_* thpool_p, struct thread** thread_p, int id){
+
+ *thread_p = (struct thread*)calloc(1,sizeof(struct thread));
+ if (thread_p == NULL){
+ fprintf(stderr, "thpool_init(): Could not allocate memory for thread\n");
+ exit(1);
+ }
+
+ (*thread_p)->thpool_p = thpool_p;
+ (*thread_p)->id = id;
+
+ pthread_create(&(*thread_p)->pthread, NULL, (void *)thread_do, (*thread_p));
+ pthread_detach((*thread_p)->pthread);
+
+}
+
+
+/* Sets the calling thread on hold */
+static void thread_hold () {
+ threads_on_hold = 1;
+ while (threads_on_hold){
+ sleep(1);
+ }
+}
+
+
+/* What each thread is doing
+*
+* In principle this is an endless loop. The only time this loop gets interuppted is once
+* thpool_destroy() is invoked or the program exits.
+*
+* @param thread thread that will run this function
+* @return nothing
+*/
+static void* thread_do(struct thread* thread_p){
+
+ /* Assure all threads have been created before starting serving */
+ thpool_* thpool_p = thread_p->thpool_p;
+
+ /* Register signal handler */
+ struct sigaction act;
+ act.sa_handler = thread_hold;
+ if (sigaction(SIGUSR1, &act, NULL) == -1) {
+ fprintf(stderr, "thread_do(): cannot handle SIGUSR1");
+ }
+
+ /* Mark thread as alive (initialized) */
+ pthread_mutex_lock(&thpool_p->thcount_lock);
+ thpool_p->num_threads_alive += 1;
+ pthread_mutex_unlock(&thpool_p->thcount_lock);
+
+ while(threads_keepalive){
+
+ bsem_wait(thpool_p->jobqueue_p->has_jobs);
+
+ if (threads_keepalive){
+
+ pthread_mutex_lock(&thpool_p->thcount_lock);
+ thpool_p->num_threads_working++;
+ pthread_mutex_unlock(&thpool_p->thcount_lock);
+
+ /* Read job from queue and execute it */
+ void*(*func_buff)(void* arg);
+ void* arg_buff;
+ job* job_p;
+ pthread_mutex_lock(&thpool_p->jobqueue_p->rwmutex);
+ job_p = jobqueue_pull(thpool_p);
+ pthread_mutex_unlock(&thpool_p->jobqueue_p->rwmutex);
+ if (job_p) {
+ func_buff = job_p->function;
+ arg_buff = job_p->arg;
+ func_buff(arg_buff);
+ free(job_p);
+ }
+
+ pthread_mutex_lock(&thpool_p->thcount_lock);
+ thpool_p->num_threads_working--;
+ pthread_mutex_unlock(&thpool_p->thcount_lock);
+
+ }
+ }
+ pthread_mutex_lock(&thpool_p->thcount_lock);
+ thpool_p->num_threads_alive --;
+ pthread_mutex_unlock(&thpool_p->thcount_lock);
+
+ return NULL;
+}
+
+
+/* Frees a thread */
+static void thread_destroy (thread* thread_p){
+ free(thread_p);
+}
+
+
+
+
+
+/* ============================ JOB QUEUE =========================== */
+
+
+/* Initialize queue */
+static int jobqueue_init(thpool_* thpool_p){
+
+ thpool_p->jobqueue_p = (struct jobqueue*)calloc(1,sizeof(struct jobqueue));
+ pthread_mutex_init(&(thpool_p->jobqueue_p->rwmutex), NULL);
+ if (thpool_p->jobqueue_p == NULL){
+ return -1;
+ }
+
+ thpool_p->jobqueue_p->has_jobs = (struct bsem*)calloc(1,sizeof(struct bsem));
+ if (thpool_p->jobqueue_p->has_jobs == NULL){
+ return -1;
+ }
+ bsem_init(thpool_p->jobqueue_p->has_jobs, 0);
+
+ jobqueue_clear(thpool_p);
+ return 0;
+}
+
+
+/* Clear the queue */
+static void jobqueue_clear(thpool_* thpool_p){
+
+ while(thpool_p->jobqueue_p->len){
+ free(jobqueue_pull(thpool_p));
+ }
+
+ thpool_p->jobqueue_p->front = NULL;
+ thpool_p->jobqueue_p->rear = NULL;
+ bsem_reset(thpool_p->jobqueue_p->has_jobs);
+ thpool_p->jobqueue_p->len = 0;
+
+}
+
+
+/* Add (allocated) job to queue
+ *
+ * Notice: Caller MUST hold a mutex
+ */
+static void jobqueue_push(thpool_* thpool_p, struct job* newjob){
+
+ newjob->prev = NULL;
+
+ switch(thpool_p->jobqueue_p->len){
+
+ case 0: /* if no jobs in queue */
+ thpool_p->jobqueue_p->front = newjob;
+ thpool_p->jobqueue_p->rear = newjob;
+ break;
+
+ default: /* if jobs in queue */
+ thpool_p->jobqueue_p->rear->prev = newjob;
+ thpool_p->jobqueue_p->rear = newjob;
+
+ }
+ thpool_p->jobqueue_p->len++;
+
+ bsem_post(thpool_p->jobqueue_p->has_jobs);
+}
+
+
+/* Get first job from queue(removes it from queue)
+ *
+ * Notice: Caller MUST hold a mutex
+ */
+static struct job* jobqueue_pull(thpool_* thpool_p){
+
+ job* job_p;
+ job_p = thpool_p->jobqueue_p->front;
+
+ switch(thpool_p->jobqueue_p->len){
+
+ case 0: /* if no jobs in queue */
+ return NULL;
+
+ case 1: /* if one job in queue */
+ thpool_p->jobqueue_p->front = NULL;
+ thpool_p->jobqueue_p->rear = NULL;
+ break;
+
+ default: /* if >1 jobs in queue */
+ thpool_p->jobqueue_p->front = job_p->prev;
+
+ }
+ thpool_p->jobqueue_p->len--;
+
+ /* Make sure has_jobs has right value */
+ if (thpool_p->jobqueue_p->len > 0) {
+ bsem_post(thpool_p->jobqueue_p->has_jobs);
+ }
+
+ return job_p;
+}
+
+
+/* Free all queue resources back to the system */
+static void jobqueue_destroy(thpool_* thpool_p){
+ jobqueue_clear(thpool_p);
+ free(thpool_p->jobqueue_p->has_jobs);
+}
+
+
+
+
+
+/* ======================== SYNCHRONISATION ========================= */
+
+
+/* Init semaphore to 1 or 0 */
+static void bsem_init(bsem *bsem_p, int value) {
+ if (value < 0 || value > 1) {
+ fprintf(stderr, "bsem_init(): Binary semaphore can take only values 1 or 0");
+ exit(1);
+ }
+ pthread_mutex_init(&(bsem_p->mutex), NULL);
+ pthread_cond_init(&(bsem_p->cond), NULL);
+ bsem_p->v = value;
+}
+
+
+/* Reset semaphore to 0 */
+static void bsem_reset(bsem *bsem_p) {
+ bsem_init(bsem_p, 0);
+}
+
+
+/* Post to at least one thread */
+static void bsem_post(bsem *bsem_p) {
+ pthread_mutex_lock(&bsem_p->mutex);
+ bsem_p->v = 1;
+ pthread_cond_signal(&bsem_p->cond);
+ pthread_mutex_unlock(&bsem_p->mutex);
+}
+
+
+/* Post to all threads */
+static void bsem_post_all(bsem *bsem_p) {
+ pthread_mutex_lock(&bsem_p->mutex);
+ bsem_p->v = 1;
+ pthread_cond_broadcast(&bsem_p->cond);
+ pthread_mutex_unlock(&bsem_p->mutex);
+}
+
+
+/* Wait on semaphore until semaphore has value 0 */
+static void bsem_wait(bsem* bsem_p) {
+ pthread_mutex_lock(&bsem_p->mutex);
+ while (bsem_p->v != 1) {
+ pthread_cond_wait(&bsem_p->cond, &bsem_p->mutex);
+ }
+ bsem_p->v = 0;
+ pthread_mutex_unlock(&bsem_p->mutex);
+}
diff --git a/src/thpool.h b/src/thpool.h
new file mode 100644
index 0000000..6e600b2
--- /dev/null
+++ b/src/thpool.h
@@ -0,0 +1,166 @@
+/**********************************
+ * @author Johan Hanssen Seferidis
+ * License: MIT
+ *
+ **********************************/
+
+#ifndef _THPOOL_
+#define _THPOOL_
+
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+/* =================================== API ======================================= */
+
+
+typedef struct thpool_* threadpool;
+
+
+/**
+ * @brief Initialize threadpool
+ *
+ * Initializes a threadpool. This function will not return until all
+ * threads have initialized successfully.
+ *
+ * @example
+ *
+ * ..
+ * threadpool thpool; //First we declare a threadpool
+ * thpool = thpool_init(4); //then we initialize it to 4 threads
+ * ..
+ *
+ * @param num_threads number of threads to be created in the threadpool
+ * @return threadpool created threadpool on success,
+ * NULL on error
+ */
+threadpool thpool_init(int num_threads);
+
+
+/**
+ * @brief Add work to the job queue
+ *
+ * Takes an action and its argument and adds it to the threadpool's job queue.
+ * If you want to add to work a function with more than one arguments then
+ * a way to implement this is by passing a pointer to a structure.
+ *
+ * NOTICE: You have to cast both the function and argument to not get warnings.
+ *
+ * @example
+ *
+ * void print_num(int num){
+ * printf("%d\n", num);
+ * }
+ *
+ * int main() {
+ * ..
+ * int a = 10;
+ * thpool_add_work(thpool, (void*)print_num, (void*)a);
+ * ..
+ * }
+ *
+ * @param threadpool threadpool to which the work will be added
+ * @param function_p pointer to function to add as work
+ * @param arg_p pointer to an argument
+ * @return 0 on success, -1 if the job could not be allocated
+ */
+int thpool_add_work(threadpool, void *(*function_p)(void*), void* arg_p);
+
+
+/**
+ * @brief Wait for all queued jobs to finish
+ *
+ * Will wait for all jobs - both queued and currently running to finish.
+ * Once the queue is empty and all work has completed, the calling thread
+ * (probably the main program) will continue.
+ *
+ * Smart polling is used in wait. The polling is initially 0 - meaning that
+ * there is virtually no polling at all. If after 1 second the threads
+ * haven't finished, the polling interval starts growing exponentially
+ * until it reaches max_secs seconds. Then it jumps down to a maximum polling
+ * interval assuming that heavy processing is being used in the threadpool.
+ *
+ * @example
+ *
+ * ..
+ * threadpool thpool = thpool_init(4);
+ * ..
+ * // Add a bunch of work
+ * ..
+ * thpool_wait(thpool);
+ * puts("All added work has finished");
+ * ..
+ *
+ * @param threadpool the threadpool to wait for
+ * @return nothing
+ */
+void thpool_wait(threadpool);
+
+
+/**
+ * @brief Pauses all threads immediately
+ *
+ * The threads will be paused no matter if they are idle or working.
+ * The threads return to their previous states once thpool_resume
+ * is called.
+ *
+ * While the thread is being paused, new work can be added.
+ *
+ * @example
+ *
+ * threadpool thpool = thpool_init(4);
+ * thpool_pause(thpool);
+ * ..
+ * // Add a bunch of work
+ * ..
+ * thpool_resume(thpool); // Let the threads start their magic
+ *
+ * @param threadpool the threadpool where the threads should be paused
+ * @return nothing
+ */
+void thpool_pause(threadpool);
+
+
+/**
+ * @brief Unpauses all threads if they are paused
+ *
+ * @example
+ * ..
+ * thpool_pause(thpool);
+ * sleep(10); // Delay execution 10 seconds
+ * thpool_resume(thpool);
+ * ..
+ *
+ * @param threadpool the threadpool where the threads should be unpaused
+ * @return nothing
+ */
+void thpool_resume(threadpool);
+
+
+/**
+ * @brief Destroy the threadpool
+ *
+ * This will wait for the currently active threads to finish and then 'kill'
+ * the whole threadpool to free up memory.
+ *
+ * @example
+ * int main() {
+ * threadpool thpool1 = thpool_init(2);
+ * threadpool thpool2 = thpool_init(2);
+ * ..
+ * thpool_destroy(thpool1);
+ * ..
+ * return 0;
+ * }
+ *
+ * @param threadpool the threadpool to destroy
+ * @return nothing
+ */
+void thpool_destroy(threadpool);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/src/timer.cpp b/src/timer.cpp
new file mode 100644
index 0000000..239a0aa
--- /dev/null
+++ b/src/timer.cpp
@@ -0,0 +1,109 @@
+#include <ctime>
+#include <iostream>
+#include <iomanip>
+#include "util.h"
+#include "timer.h"
+
+
+/***
+ * Return the total time that the timer has been in the "running"
+ * state since it was first "started" or last "restarted". For
+ * "short" time periods (less than an hour), the actual cpu time
+ * used is reported instead of the elapsed time.
+ */
+double Timer::elapsed_time()
+{
+ time_t now;
+ time(&now);
+ return difftime(now, start_time);
+}
+
+/***
+ * Return the total time that the timer has been in the "running"
+ * state since it was first "started" or last "restarted". For
+ * "short" time periods (less than an hour), the actual cpu time
+ * used is reported instead of the elapsed time.
+ * This function is the public version of elapsed_time()
+ */
+double Timer::get_elapsed_time()
+{
+ return elapsed_time();
+}
+
+/***
+ * Start a timer. If it is already running, let it continue running.
+ * Print an optional message.
+ */
+void Timer::start(const char* msg)
+{
+ // Print an optional message, something like "Starting timer t";
+ if (msg) VERBOSE(0, msg << std::endl);
+
+ // Return immediately if the timer is already running
+ if (running) return;
+
+ // Change timer status to running
+ running = true;
+
+ // Set the start time;
+ time(&start_time);
+}
+
+/***
+ * Turn the timer off and start it again from 0. Print an optional message.
+ */
+/*
+inline void Timer::restart(const char* msg)
+{
+ // Print an optional message, something like "Restarting timer t";
+ if (msg) VERBOSE(0, msg << std::endl);
+
+ // Set the timer status to running
+ running = true;
+
+ // Set the accumulated time to 0 and the start time to now
+ acc_time = 0;
+ start_clock = clock();
+ start_time = time(0);
+}
+*/
+
+/***
+ * Stop the timer and print an optional message.
+ */
+/*
+inline void Timer::stop(const char* msg)
+{
+ // Print an optional message, something like "Stopping timer t";
+ check(msg);
+
+ // Recalculate and store the total accumulated time up until now
+ if (running) acc_time += elapsed_time();
+
+ running = false;
+}
+*/
+/***
+ * Print out an optional message followed by the current timer timing.
+ */
+void Timer::check(const char* msg)
+{
+ // Print an optional message, something like "Checking timer t";
+ if (msg) VERBOSE(0, msg << " : ");
+
+ VERBOSE(0, "[" << (running ? elapsed_time() : 0) << "] seconds\n");
+}
+
+/***
+ * Allow timers to be printed to ostreams using the syntax 'os << t'
+ * for an ostream 'os' and a timer 't'. For example, "cout << t" will
+ * print out the total amount of time 't' has been "running".
+ */
+std::ostream& operator<<(std::ostream& os, Timer& t)
+{
+ //os << std::setprecision(2) << std::setiosflags(std::ios::fixed) << (t.running ? t.elapsed_time() : 0);
+ os << (t.running ? t.elapsed_time() : 0);
+ return os;
+}
+
+
diff --git a/src/timer.h b/src/timer.h
new file mode 100644
index 0000000..2e2fea5
--- /dev/null
+++ b/src/timer.h
@@ -0,0 +1,35 @@
+#ifndef TIMER_H
+#define TIMER_H
+
+#include <ctime>
+#include <iostream>
+#include <iomanip>
+#include "util.h"
+
// Simple wall-clock stopwatch. A Timer is created stopped and must be
// started explicitly with start(); check() and operator<< report the
// elapsed time of a running timer.
class Timer
{
	friend std::ostream& operator<<(std::ostream& os, Timer& t);

private:
	bool running;        // true between start() and a future stop()
	time_t start_time;   // wall-clock time recorded by start()

	// Wall-clock seconds since start_time.
	double elapsed_time();

public:
	/***
	 * 'running' is initially false. A timer needs to be explicitly started
	 * using 'start' or 'restart'
	 */
	Timer() : running(false), start_time(0) { }

	void start(const char* msg = 0);
//	void restart(const char* msg = 0);
//	void stop(const char* msg = 0);
	void check(const char* msg = 0);
	double get_elapsed_time();

};
+
+#endif // TIMER_H
diff --git a/src/tlm.cpp b/src/tlm.cpp
new file mode 100644
index 0000000..3daefb7
--- /dev/null
+++ b/src/tlm.cpp
@@ -0,0 +1,586 @@
+
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+
+#include <iostream>
+#include <cmath>
+#include <math.h>
+#include "cmd.h"
+#include "mfstream.h"
+#include "mempool.h"
+#include "htable.h"
+#include "dictionary.h"
+#include "n_gram.h"
+#include "mempool.h"
+#include "ngramtable.h"
+#include "interplm.h"
+#include "normcache.h"
+#include "ngramcache.h"
+#include "mdiadapt.h"
+#include "shiftlm.h"
+#include "linearlm.h"
+#include "mixture.h"
+#include "lmtable.h"
+
+/********************************/
+using namespace std;
+using namespace irstlm;
+
+
+#define NGRAM 1
+#define SEQUENCE 2
+#define ADAPT 3
+#define TURN 4
+#define TEXT 5
+
+static Enum_T LmTypeEnum [] = {
+ { (char*)"ImprovedKneserNey", IMPROVED_KNESER_NEY },
+ { (char*)"ikn", IMPROVED_KNESER_NEY },
+ { (char*)"KneserNey", KNESER_NEY },
+ { (char*)"kn", KNESER_NEY },
+ { (char*)"ModifiedShiftBeta", MOD_SHIFT_BETA },
+ { (char*)"msb", MOD_SHIFT_BETA },
+ { (char*)"ImprovedShiftBeta", IMPROVED_SHIFT_BETA },
+ { (char*)"isb", IMPROVED_SHIFT_BETA },
+ { (char*)"InterpShiftBeta", SHIFT_BETA },
+ { (char*)"ShiftBeta", SHIFT_BETA },
+ { (char*)"sb", SHIFT_BETA },
+ { (char*)"InterpShiftOne", SHIFT_ONE },
+ { (char*)"ShiftOne", SHIFT_ONE },
+ { (char*)"s1", SHIFT_ONE },
+ { (char*)"LinearWittenBell", LINEAR_WB },
+ { (char*)"wb", LINEAR_WB },
+ { (char*)"StupidBackoff", LINEAR_STB },
+ { (char*)"stb", LINEAR_STB },
+ { (char*)"LinearGoodTuring", LINEAR_GT },
+ { (char*)"Mixture", MIXTURE },
+ { (char*)"mix", MIXTURE },
+ END_ENUM
+};
+
+static Enum_T InteractiveModeEnum [] = {
+ { (char*)"Ngram", NGRAM },
+ { (char*)"Sequence", SEQUENCE },
+ { (char*)"Adapt", ADAPT },
+ { (char*)"Turn", TURN },
+ { (char*)"Text", TEXT },
+ { (char*)"Yes", NGRAM },
+ END_ENUM
+};
+
+void print_help(int TypeFlag=0){
+ std::cerr << std::endl << "tlm - estimates a language model" << std::endl;
+ std::cerr << std::endl << "USAGE:" << std::endl;
+ std::cerr << " not yet available" << std::endl;
+ std::cerr << std::endl << "DESCRIPTION:" << std::endl;
+ std::cerr << " tlm is a tool for the estimation of language model" << std::endl;
+ std::cerr << std::endl << "OPTIONS:" << std::endl;
+ std::cerr << " -Help|-h this help" << std::endl;
+ std::cerr << std::endl;
+
+ FullPrintParams(TypeFlag, 0, 1, stderr);
+}
+
+void usage(const char *msg = 0)
+{
+ if (msg){
+ std::cerr << msg << std::endl;
+ }
+ else{
+ print_help();
+ }
+}
+
+int main(int argc, char **argv)
+{
+ char *dictfile=NULL;
+ char *trainfile=NULL;
+ char *testfile=NULL;
+ char *adaptfile=NULL;
+ char *slminfo=NULL;
+
+ char *imixpar=NULL;
+ char *omixpar=NULL;
+
+ char *BINfile=NULL;
+ char *ARPAfile=NULL;
+ bool SavePerLevel=true; //save-per-level or save-for-word
+
+ char *ASRfile=NULL;
+
+ char* scalefactorfile=NULL;
+
+ bool backoff=false; //back-off or interpolation
+ int lmtype=0;
+ int dub=10000000; //dictionary upper bound
+ int size=0; //lm size
+
+ int interactive=0;
+ int statistics=0;
+
+ int prunefreq=0;
+ bool prunesingletons=true;
+ bool prunetopsingletons=false;
+ char *prune_thr_str=NULL;
+
+ double beta=-1;
+
+ bool compsize=false;
+ bool checkpr=false;
+ double oovrate=0;
+ int max_caching_level=0;
+
+ char *outpr=NULL;
+
+ bool memmap = false; //write binary format with/without memory map, default is 0
+
+ int adaptlevel=0; //adaptation level
+ double adaptrate=1.0;
+ bool adaptoov=false; //do not increment the dictionary
+
+ bool help=false;
+
+ DeclareParams((char*)
+ "Back-off",CMDBOOLTYPE|CMDMSG, &backoff, "boolean flag for backoff LM (default is false, i.e. interpolated LM)",
+ "bo",CMDBOOLTYPE|CMDMSG, &backoff, "boolean flag for backoff LM (default is false, i.e. interpolated LM)",
+ "Dictionary", CMDSTRINGTYPE|CMDMSG, &dictfile, "dictionary to filter the LM (default is NULL)",
+ "d", CMDSTRINGTYPE|CMDMSG, &dictfile, "dictionary to filter the LM (default is NULL)",
+
+ "DictionaryUpperBound", CMDINTTYPE|CMDMSG, &dub, "dictionary upperbound to compute OOV word penalty: default 10^7",
+ "dub", CMDINTTYPE|CMDMSG, &dub, "dictionary upperbound to compute OOV word penalty: default 10^7",
+
+ "NgramSize", CMDSUBRANGETYPE|CMDMSG, &size, 1, MAX_NGRAM, "order of the LM",
+ "n", CMDSUBRANGETYPE|CMDMSG, &size, 1, MAX_NGRAM, "order of the LM",
+
+ "Ngram", CMDSTRINGTYPE|CMDMSG, &trainfile, "training file",
+ "TrainOn", CMDSTRINGTYPE|CMDMSG, &trainfile, "training file",
+ "tr", CMDSTRINGTYPE|CMDMSG, &trainfile, "training file",
+
+ "oASR", CMDSTRINGTYPE|CMDMSG, &ASRfile, "output file in ASR format",
+ "oasr", CMDSTRINGTYPE|CMDMSG, &ASRfile, "output file in ASR format",
+
+ "o", CMDSTRINGTYPE|CMDMSG, &ARPAfile, "output file in ARPA format",
+ "oARPA", CMDSTRINGTYPE|CMDMSG, &ARPAfile, "output file in ARPA format",
+ "oarpa", CMDSTRINGTYPE|CMDMSG, &ARPAfile, "output file in ARPA format",
+
+ "oBIN", CMDSTRINGTYPE|CMDMSG, &BINfile, "output file in binary format",
+ "obin", CMDSTRINGTYPE|CMDMSG, &BINfile, "output file in binary format",
+
+ "SavePerLevel",CMDBOOLTYPE|CMDMSG, &SavePerLevel, "saving type of the LM (true: per level (default), false: per word)",
+ "spl",CMDBOOLTYPE|CMDMSG, &SavePerLevel, "saving type of the LM (true: per level (default), false: per word)",
+
+ "TestOn", CMDSTRINGTYPE|CMDMSG, &testfile, "file for testing",
+ "te", CMDSTRINGTYPE|CMDMSG, &testfile, "file for testing",
+
+ "AdaptOn", CMDSTRINGTYPE|CMDMSG, &adaptfile, "file for adaptation",
+ "ad", CMDSTRINGTYPE|CMDMSG, &adaptfile, "file for adaptation",
+
+ "AdaptRate",CMDDOUBLETYPE|CMDMSG , &adaptrate, "adaptation rate",
+ "ar", CMDDOUBLETYPE|CMDMSG, &adaptrate, "adaptation rate",
+
+ "AdaptLevel", CMDSUBRANGETYPE|CMDMSG, &adaptlevel, 1 , MAX_NGRAM, "adaptation level",
+ "al",CMDSUBRANGETYPE|CMDMSG, &adaptlevel, 1, MAX_NGRAM, "adaptation level",
+
+ "AdaptOOV", CMDBOOLTYPE|CMDMSG, &adaptoov, "boolean flag for increasing the dictionary during adaptation (default is false)",
+ "ao", CMDBOOLTYPE|CMDMSG, &adaptoov, "boolean flag for increasing the dictionary during adaptation (default is false)",
+
+ "SaveScaleFactor", CMDSTRINGTYPE|CMDMSG, &scalefactorfile, "output file for the scale factors",
+ "ssf", CMDSTRINGTYPE|CMDMSG, &scalefactorfile, "output file for the scale factors",
+
+ "LanguageModelType",CMDENUMTYPE|CMDMSG, &lmtype, LmTypeEnum, "type of the LM",
+ "lm",CMDENUMTYPE|CMDMSG, &lmtype, LmTypeEnum, "type of the LM",
+
+ "Interactive",CMDENUMTYPE|CMDMSG, &interactive, InteractiveModeEnum, "type of interaction",
+ "i",CMDENUMTYPE|CMDMSG, &interactive, InteractiveModeEnum, "type of interaction",
+
+ "Statistics",CMDSUBRANGETYPE|CMDMSG, &statistics, 1, 3, "output statistics of the LM of increasing detail (default is 0)",
+ "s",CMDSUBRANGETYPE|CMDMSG, &statistics, 1, 3, "output statistics of the LM of increasing detail (default is 0)",
+
+ "PruneThresh",CMDSUBRANGETYPE|CMDMSG, &prunefreq, 0, 1000, "threshold for pruning (default is 0)",
+ "p",CMDSUBRANGETYPE|CMDMSG, &prunefreq, 0, 1000, "threshold for pruning (default is 0)",
+
+ "PruneSingletons",CMDBOOLTYPE|CMDMSG, &prunesingletons, "boolean flag for pruning of singletons (default is true)",
+ "ps",CMDBOOLTYPE|CMDMSG, &prunesingletons, "boolean flag for pruning of singletons (default is true)",
+
+ "PruneTopSingletons",CMDBOOLTYPE|CMDMSG, &prunetopsingletons, "boolean flag for pruning of singletons at the top level (default is false)",
+ "pts",CMDBOOLTYPE|CMDMSG, &prunetopsingletons, "boolean flag for pruning of singletons at the top level (default is false)",
+
+ "PruneFrequencyThreshold",CMDSTRINGTYPE|CMDMSG, &prune_thr_str, "pruning frequency threshold for each level; comma-separated list of values; (default is \"0,0,...,0\", for all levels)",
+ "pft",CMDSTRINGTYPE|CMDMSG, &prune_thr_str, "pruning frequency threshold for each level; comma-separated list of values; (default is \"0,0,...,0\", for all levels)",
+
+ "ComputeLMSize",CMDBOOLTYPE|CMDMSG, &compsize, "boolean flag for output the LM size (default is false)",
+ "sz",CMDBOOLTYPE|CMDMSG, &compsize, "boolean flag for output the LM size (default is false)",
+
+ "MaximumCachingLevel", CMDINTTYPE|CMDMSG , &max_caching_level, "maximum level for caches (default is: LM order - 1)",
+ "mcl", CMDINTTYPE|CMDMSG, &max_caching_level, "maximum level for caches (default is: LM order - 1)",
+
+ "MemoryMap", CMDBOOLTYPE|CMDMSG, &memmap, "use memory mapping for bianry saving (default is false)",
+ "memmap", CMDBOOLTYPE|CMDMSG, &memmap, "use memory mapping for bianry saving (default is false)",
+ "mm", CMDBOOLTYPE|CMDMSG, &memmap, "use memory mapping for bianry saving (default is false)",
+
+ "CheckProb",CMDBOOLTYPE|CMDMSG, &checkpr, "boolean flag for checking probability distribution during test (default is false)",
+ "cp",CMDBOOLTYPE|CMDMSG, &checkpr, "boolean flag for checking probability distribution during test (default is false)",
+
+ "OutProb",CMDSTRINGTYPE|CMDMSG, &outpr, "output file for debugging during test (default is \"/dev/null\")",
+ "op",CMDSTRINGTYPE|CMDMSG, &outpr, "output file for debugging during test (default is \"/dev/null\")",
+
+ "SubLMInfo", CMDSTRINGTYPE|CMDMSG, &slminfo, "configuration file for the mixture LM",
+ "slmi", CMDSTRINGTYPE|CMDMSG, &slminfo, "configuration file for the mixture LM",
+
+ "SaveMixParam", CMDSTRINGTYPE|CMDMSG, &omixpar, "output file for weights of the mixture LM",
+ "smp", CMDSTRINGTYPE|CMDMSG, &omixpar, "output file for weights of the mixture LM",
+
+ "LoadMixParam", CMDSTRINGTYPE|CMDMSG, &imixpar, "input file for weights of the mixture LM",
+ "lmp", CMDSTRINGTYPE|CMDMSG, &imixpar, "input file for weights of the mixture LM",
+
+ "SetOovRate", CMDDOUBLETYPE|CMDMSG, &oovrate, "rate for computing the OOV frequency (=oovrate*totfreq if oovrate>0) (default is 0)",
+ "or", CMDDOUBLETYPE|CMDMSG, &oovrate, "rate for computing the OOV frequency (=oovrate*totfreq if oovrate>0) (default is 0)",
+
+ "Beta", CMDDOUBLETYPE|CMDMSG, &beta, "beta value for Shift-Beta and Kneser-Ney LMs (default is -1, i.e. automatic estimation)",
+ "beta", CMDDOUBLETYPE|CMDMSG, &beta, "beta value for Shift-Beta and Kneser-Ney LMs (default is -1, i.e. automatic estimation)",
+
+ "Help", CMDBOOLTYPE|CMDMSG, &help, "print this help",
+ "h", CMDBOOLTYPE|CMDMSG, &help, "print this help",
+
+ (char *)NULL
+ );
+
+ if (argc == 1){
+ usage();
+ exit_error(IRSTLM_NO_ERROR);
+ }
+
+ GetParams(&argc, &argv, (char*) NULL);
+
+ if (help){
+ usage();
+ exit_error(IRSTLM_NO_ERROR);
+ }
+
+ if (!lmtype) {
+ exit_error(IRSTLM_ERROR_DATA,"The lm type (-lm) is not specified");
+ }
+
+ if (!trainfile && lmtype!=MIXTURE) {
+ exit_error(IRSTLM_ERROR_DATA,"The LM file (-tr) is not specified");
+ }
+
+ if (SavePerLevel == false && backoff == true){
+ cerr << "WARNING: Current implementation does not support the usage of backoff (-bo=true) mixture models (-lm=mix) combined with the per-word saving (-saveperllevel=false)." << endl;
+ cerr << "WARNING: The usage of backoff is disabled, i.e. -bo=no is forced" << endl;
+
+ backoff=false;
+ }
+
+ mdiadaptlm *lm=NULL;
+
+ switch (lmtype) {
+
+ case SHIFT_BETA:
+ if (beta==-1 || (beta<1.0 && beta>0)){
+ lm=new shiftbeta(trainfile,size,prunefreq,beta,(backoff?SHIFTBETA_B:SHIFTBETA_I));
+ } else {
+ exit_error(IRSTLM_ERROR_DATA,"ShiftBeta: beta must be >0 and <1");
+ }
+ break;
+
+ case KNESER_NEY:
+ if (size>1){
+ if (beta==-1 || (beta<1.0 && beta>0)){
+// lm=new kneserney(trainfile,size,prunefreq,beta,(backoff?KNESERNEY_B:KNESERNEY_I));
+ } else {
+ exit_error(IRSTLM_ERROR_DATA,"Kneser-Ney: beta must be >0 and <1");
+ }
+ } else {
+ exit_error(IRSTLM_ERROR_DATA,"Kneser-Ney requires size >1");
+ }
+ break;
+
+ case MOD_SHIFT_BETA:
+ cerr << "ModifiedShiftBeta (msb) is the old name for ImprovedKneserNey (ikn); this name is not supported anymore, but it is mapped into ImprovedKneserNey for back-compatibility";
+ case IMPROVED_KNESER_NEY:
+ if (size>1){
+ lm=new improvedkneserney(trainfile,size,prunefreq,(backoff?IMPROVEDKNESERNEY_B:IMPROVEDKNESERNEY_I));
+ } else {
+ exit_error(IRSTLM_ERROR_DATA,"Improved Kneser-Ney requires size >1");
+ }
+ break;
+
+ case IMPROVED_SHIFT_BETA:
+ lm=new improvedshiftbeta(trainfile,size,prunefreq,(backoff?IMPROVEDSHIFTBETA_B:IMPROVEDSHIFTBETA_I));
+ break;
+
+ case SHIFT_ONE:
+ lm=new shiftone(trainfile,size,prunefreq,(backoff?SIMPLE_B:SIMPLE_I));
+ break;
+
+ case LINEAR_STB:
+ lm=new linearstb(trainfile,size,prunefreq,IMPROVEDSHIFTBETA_B);
+ break;
+
+ case LINEAR_WB:
+ lm=new linearwb(trainfile,size,prunefreq,(backoff?IMPROVEDSHIFTBETA_B:IMPROVEDSHIFTBETA_I));
+ break;
+
+ case LINEAR_GT:
+ cerr << "This LM is no more supported\n";
+ break;
+
+ case MIXTURE:
+ //temporary check: so far unable to proper handle this flag in sub LMs
+ //no ngramtable is created
+ lm=new mixture(SavePerLevel,slminfo,size,prunefreq,imixpar,omixpar);
+ break;
+
+ default:
+ cerr << "not implemented yet\n";
+ return 1;
+ };
+
+ if (dub < lm->dict->size()){
+ cerr << "dub (" << dub << ") is not set or too small. dub is re-set to the dictionary size (" << lm->dict->size() << ")" << endl;
+ dub = lm->dict->size();
+ }
+
+ lm->dub(dub);
+
+ lm->create_caches(max_caching_level);
+
+ cerr << "eventually generate OOV code\n";
+ lm->dict->genoovcode();
+
+ if (oovrate) lm->dict->setoovrate(oovrate);
+
+ lm->save_per_level(SavePerLevel);
+
+ lm->train();
+
+ //it never occurs that both prunetopsingletons and prunesingletons are true
+ if (prunetopsingletons==true) { //keep most specific
+ lm->prunetopsingletons(true);
+ lm->prunesingletons(false);
+ } else {
+ lm->prunetopsingletons(false);
+ if (prunesingletons==true) {
+ lm->prunesingletons(true);
+ } else {
+ lm->prunesingletons(false);
+ }
+ }
+ if (prune_thr_str) lm->set_prune_ngram(prune_thr_str);
+
+ if (adaptoov) lm->dict->incflag(1);
+
+ if (adaptfile) lm->adapt(adaptfile,adaptlevel,adaptrate);
+
+ if (adaptoov) lm->dict->incflag(0);
+
+ if (scalefactorfile) lm->savescalefactor(scalefactorfile);
+
+ if (backoff) lm->compute_backoff();
+
+ if (size>lm->maxlevel()) {
+ exit_error(IRSTLM_ERROR_DATA,"lm size is too large");
+ }
+
+ if (!size) size=lm->maxlevel();
+
+ if (testfile) {
+ cerr << "TLM: test ...";
+ lm->test(testfile,size,backoff,checkpr,outpr);
+
+ if (adaptfile)
+ ((mdiadaptlm *)lm)->get_zetacache()->stat();
+
+ cerr << "\n";
+ };
+
+ if (compsize)
+ cout << "LM size " << (int)lm->netsize() << "\n";
+
+ if (interactive) {
+
+ ngram ng(lm->dict);
+ int nsize=0;
+
+ cout.setf(ios::scientific);
+
+ switch (interactive) {
+
+ case NGRAM:
+ cout << "> ";
+ while(cin >> ng) {
+ if (ng.wordp(size)) {
+ cout << ng << " p=" << (double)log(lm->prob(ng,size)) << "\n";
+ ng.size=0;
+ cout << "> ";
+ }
+ }
+ break;
+
+ case SEQUENCE: {
+ char c;
+ double p=0;
+ cout << "> ";
+
+ while(cin >> ng) {
+ nsize=ng.size<size?ng.size:size;
+ p=log(lm->prob(ng,nsize));
+ cout << ng << " p=" << p << "\n";
+
+ while((c=cin.get())==' ') {
+ cout << c;
+ }
+ cin.putback(c);
+ //cout << "-" << c << "-";
+ if (c=='\n') {
+ ng.size=0;
+ cout << "> ";
+ p=0;
+ }
+ }
+ }
+
+ break;
+
+ case TURN: {
+ int n=0;
+ double lp=0;
+ double oov=0;
+
+ while(cin >> ng) {
+
+ if (ng.size>0) {
+ nsize=ng.size<size?ng.size:size;
+ lp-=log(lm->prob(ng,nsize));
+ n++;
+ if (*ng.wordp(1) == lm->dict->oovcode())
+ oov++;
+ } else {
+ if (n>0) cout << n << " " << lp/(log(2.0) * n) << " " << oov/n << "\n";
+ n=0;
+ lp=0;
+ oov=0;
+ }
+ }
+
+ break;
+ }
+
+ case TEXT: {
+ int order;
+
+ int n=0;
+ double lp=0;
+ double oov=0;
+
+ while (!cin.eof()) {
+ cin >> order;
+ if (order>size)
+ cerr << "Warning: order > lm size\n";
+
+ order=order>size?size:order;
+
+ while (cin >> ng) {
+ if (ng.size>0) {
+ nsize=ng.size<order?ng.size:order;
+ lp-=log(lm->prob(ng,nsize));
+ n++;
+ if (*ng.wordp(1) == lm->dict->oovcode())
+ oov++;
+ } else {
+ if (n>0) cout << n << " " << lp/(log(2.0)*n) << " " << oov/n << "\n";
+ n=0;
+ lp=0;
+ oov=0;
+ if (ng.isym>0) break;
+ }
+ }
+ }
+ }
+ break;
+
+ case ADAPT: {
+
+ if (backoff) {
+ exit_error(IRSTLM_ERROR_DATA,"This modality is not supported with backoff LMs");
+ }
+
+ char afile[50],tfile[50];
+ while (!cin.eof()) {
+ cin >> afile >> tfile;
+ system("echo > .tlmlock");
+
+ cerr << "interactive adaptation: "
+ << afile << " " << tfile << "\n";
+
+ if (adaptoov) lm->dict->incflag(1);
+ lm->adapt(afile,adaptlevel,adaptrate);
+ if (adaptoov) lm->dict->incflag(0);
+ if (scalefactorfile) lm->savescalefactor(scalefactorfile);
+ if (ASRfile) lm->saveASR(ASRfile,backoff,dictfile);
+ if (ARPAfile) lm->saveARPA(ARPAfile,backoff,dictfile);
+ if (BINfile) lm->saveBIN(BINfile,backoff,dictfile,memmap);
+ lm->test(tfile,size,checkpr);
+ cout.flush();
+ system("rm .tlmlock");
+ }
+ }
+ break;
+ }
+
+ exit_error(IRSTLM_NO_ERROR);
+ }
+
+ if (ASRfile) {
+ cerr << "TLM: save lm (ASR)...";
+ lm->saveASR(ASRfile,backoff,dictfile);
+ cerr << "\n";
+ }
+
+ if (ARPAfile) {
+ cerr << "TLM: save lm (ARPA)...";
+ lm->saveARPA(ARPAfile,backoff,dictfile);
+ cerr << "\n";
+ }
+
+ if (BINfile) {
+ cerr << "TLM: save lm (binary)...";
+ lm->saveBIN(BINfile,backoff,dictfile,memmap);
+ cerr << "\n";
+ }
+
+ if (statistics) {
+ cerr << "TLM: lm stat ...";
+ lm->lmstat(statistics);
+ cerr << "\n";
+ }
+
+ // lm->cache_stat();
+
+ cerr << "TLM: deleting lm ...";
+ delete lm;
+ cerr << "\n";
+
+ exit_error(IRSTLM_NO_ERROR);
+}
+
+
+
diff --git a/src/util.cpp b/src/util.cpp
new file mode 100644
index 0000000..77b8972
--- /dev/null
+++ b/src/util.cpp
@@ -0,0 +1,369 @@
+// $Id: util.cpp 363 2010-02-22 15:02:45Z mfederico $
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+
+#ifdef WIN32
+#include <windows.h>
+#include <string.h>
+#include <io.h>
+#else
+#include <cstring>
+#include <cstdlib>
+#include <iostream>
+#include <sstream>
+#include <sys/types.h>
+#include <sys/mman.h>
+#endif
+#include "gzfilebuf.h"
+#include "timer.h"
+#include "util.h"
+#include "n_gram.h"
+#include "mfstream.h"
+
+using namespace std;
+
+// Return the system temporary directory, always with a trailing path
+// separator. Falls back to a sensible default when the TMP environment
+// variable is unset or empty.
+string gettempfolder()
+{
+#ifdef _WIN32
+ char *tmpPath = getenv("TMP");
+ // getenv() may return NULL: guard before constructing the string.
+ // (The original dereferenced it unconditionally -> undefined behavior;
+ // the POSIX branch below already had this guard.)
+ if (!tmpPath || !*tmpPath)
+  return ".\\";
+ string str(tmpPath);
+ if (str.substr(str.size() - 1, 1) != "\\")
+ str += "\\";
+ return str;
+#else
+ char *tmpPath = getenv("TMP");
+ if (!tmpPath || !*tmpPath)
+ return "/tmp/";
+ string str(tmpPath);
+ if (str.substr(str.size() - 1, 1) != "/")
+ str += "/";
+ return str;
+#endif
+}
+
+// Build a unique temporary file name inside gettempfolder().
+// POSIX: mkstemp() actually creates the file (race-free) and the
+// descriptor is closed immediately, so only the name is returned;
+// the caller is expected to reopen the path itself.
+string createtempName()
+{
+ string tmpfolder = gettempfolder();
+#ifdef _WIN32
+ char buffer[BUFSIZ];
+ // NOTE(review): GetTempFileNameA creates the file on disk itself;
+ // confirm whether the caller must delete/reopen it explicitly.
+ ::GetTempFileNameA(tmpfolder.c_str(), "", 0, buffer);
+#else
+ // NOTE(review): a variable-length array is a compiler extension in C++;
+ // mkstemp()/close() need <stdlib.h>/<unistd.h> - presumably pulled in
+ // transitively by the headers above. TODO confirm.
+ char buffer[tmpfolder.size() + 16];
+ strcpy(buffer, tmpfolder.c_str());
+ strcat(buffer, "dskbuff--XXXXXX");
+ int fd=mkstemp(buffer);
+ close(fd);
+#endif
+ return (string) buffer;
+}
+
+// Create a fresh temporary file name and open it on fileStream with the
+// given open-mode flags; the generated path is returned through filePath
+// so the caller can later removefile() it.
+void createtempfile(mfstream &fileStream, string &filePath, std::ios_base::openmode flags)
+{
+ filePath = createtempName();
+ fileStream.open(filePath.c_str(), flags);
+}
+
+// Delete the file at filePath. On POSIX a failed remove() is fatal
+// (perror + exit_error); on Windows the DeleteFileA result is ignored.
+void removefile(const std::string &filePath)
+{
+#ifdef _WIN32
+ ::DeleteFileA(filePath.c_str());
+#else
+ if (remove(filePath.c_str()) != 0)
+ {
+ perror("Error deleting file" );
+ exit_error(IRSTLM_ERROR_IO);
+ }
+#endif
+}
+
+/* MemoryMap Management
+ Code kindly provided by Fabio Brugnara, ITC-irst Trento.
+ How to use it:
+ - call MMap with offset and required size (psgz):
+ pg->b = MMap(fd, rdwr,offset,pgsz,&g);
+ - correct returned pointer with the alignment gap and save the gap:
+ pg->b += pg->gap = g;
+ - when releasing mapped memory, subtract the gap from the pointer and add
+ the gap to the requested dimension
+ Munmap(pg->b-pg->gap, pgsz+pg->gap, 0);
+ */
+
+
+// Map len bytes of file fd into memory with the given access flags.
+// mmap() requires a page-aligned offset, so the mapping actually starts
+// at the page boundary at or below offset; the distance to the requested
+// offset (the "gap") is written to *gap and must be applied by the
+// caller as described in the usage notes above. Returns NULL on failure.
+// The Windows implementation is disabled (commented out, untested).
+void *MMap(int fd, int access, off_t offset, size_t len, off_t *gap)
+{
+ void *p=NULL;
+
+#ifdef _WIN32
+ /*
+ int g=0;
+ // code for windows must be checked
+ HANDLE fh,
+ mh;
+
+ fh = (HANDLE)_get_osfhandle(fd);
+ if(offset) {
+ // must make sure the offset has the correct
+ //granularity, NEVER TESTED!
+ SYSTEM_INFO si;
+
+ GetSystemInfo(&si);
+ g = *gap = offset % si.dwPageSize;
+ } else if(gap) {
+ *gap=0;
+ }
+ if(!(mh=CreateFileMapping(fh, NULL, PAGE_READWRITE, 0, len+g, NULL))) {
+ return 0;
+ }
+ p = (char*)MapViewOfFile(mh, FILE_MAP_ALL_ACCESS, 0,
+ offset-*gap, len+*gap);
+ CloseHandle(mh);
+ */
+
+#else
+ int pgsz,g=0;
+ if(offset) {
+ // align the start of the mapping down to a page boundary
+ pgsz = sysconf(_SC_PAGESIZE);
+ g = *gap = offset%pgsz;
+ } else if(gap) {
+ *gap=0;
+ }
+ p = mmap((void*)0, len+g, access,
+ MAP_SHARED|MAP_FILE,
+ fd, offset-g);
+ if((long)p==-1L)
+ {
+ perror("mmap failed");
+ p=0;
+ }
+#endif
+ return p;
+}
+
+
+// Release a mapping previously obtained via MMap(). The caller must pass
+// the page-aligned pointer and the gap-adjusted length (see the usage
+// notes above MMap). When sync is non-zero the mapping is flushed to
+// disk (msync, MS_SYNC) before unmapping. Returns munmap()'s result
+// (0 on success). The Windows implementation is disabled.
+int Munmap(void *p,size_t len,int sync)
+{
+ int r=0;
+
+#ifdef _WIN32
+ /*
+ //code for windows must be checked
+ if(sync) FlushViewOfFile(p, len);
+ UnmapViewOfFile(p);
+ */
+#else
+ // NOTE(review): unconditional debug tracing on stderr was left enabled
+ // here; consider removing it or routing it through VERBOSE().
+ cerr << "len = " << len << endl;
+ cerr << "sync = " << sync << endl;
+ cerr << "running msync..." << endl;
+ if(sync) msync(p, len, MS_SYNC);
+ cerr << "done. Running munmap..." << endl;
+ if((r=munmap((void*)p, len)))
+ {
+ perror("munmap() failed");
+ }
+ cerr << "done" << endl;
+
+#endif
+ return r;
+}
+
+
+//global stopwatch shared by the user-time helpers below
+Timer g_timer;
+
+
+// Restart the global timer; subsequent PrintUserTime()/GetUserTime()
+// calls measure elapsed time from this point.
+void ResetUserTime()
+{
+ g_timer.start();
+}; // NOTE(review): stray trailing ';' after the function body - harmless
+
+// Report the elapsed time together with message (delegates to Timer::check).
+void PrintUserTime(const std::string &message)
+{
+ g_timer.check(message.c_str());
+}
+
+// Return the elapsed time (as reported by Timer::get_elapsed_time)
+// since the last ResetUserTime().
+double GetUserTime()
+{
+ return g_timer.get_elapsed_time();
+}
+
+
+// Print a coarse percentage (00-99) of current/target on stderr, then
+// backspace over it so the next call overwrites the figure in place.
+// Only emits on exact multiples of 1% (frac % 10 == 0) to limit output.
+// NOTE(review): target must be > 0 (integer division by target) and
+// current*1000 can overflow for values near LLONG_MAX/1000 - confirm
+// callers respect this.
+void ShowProgress(long long current, long long target){
+
+ int frac=(current * 1000)/target;
+ if (!(frac % 10)) fprintf(stderr,"%02d\b\b",frac/10);
+
+}
+
+
+// Tokenize sentence in place on blanks/tabs/CR/LF, storing up to max
+// pointers (into the mutated sentence buffer) in words[]. Returns the
+// number of tokens found and, when there is room, NULL-terminates the
+// array.
+// NOTE(review): uses strtok(), so it mutates its input and is neither
+// reentrant nor thread-safe.
+int parseWords(char *sentence, const char **words, int max)
+{
+ char *word;
+ int i = 0;
+
+ const char *const wordSeparators = " \t\r\n";
+
+ for (word = strtok(sentence, wordSeparators);
+ i < max && word != 0;
+ i++, word = strtok(0, wordSeparators)) {
+ words[i] = word;
+ }
+
+ if (i < max) {
+ words[i] = 0;
+ }
+
+ return i;
+}
+
+
+//Load a LM as a text file. LM could have been generated either with the
+//IRST LM toolkit or with the SRILM Toolkit. In the latter we are not
+//sure that n-grams are lexically ordered (according to the 1-grams).
+//However, we make the following assumption:
+//"all successors of any prefix are sorted and written in contiguous lines!"
+//This method also loads files processed with the quantization
+//tool: qlm
+
+// Parse one ARPA-style LM line of the form "logprob w1 ... wOrder [logbow]".
+// The Order words are pushed onto ng (mapping "<unk>" to the dictionary's
+// OOV symbol), prob receives the log-probability and bow the back-off
+// weight (0.0 when the optional field is absent). A line at the MAX_LINE
+// limit is fatal (exit_error). Always returns 1.
+int parseline(istream& inp, int Order,ngram& ng,float& prob,float& bow)
+{
+
+ const char* words[1+ LMTMAXLEV + 1 + 1];
+ int howmany;
+ char line[MAX_LINE];
+
+ inp.getline(line,MAX_LINE);
+ if (strlen(line)==MAX_LINE-1) {
+ std::stringstream ss_msg;
+ ss_msg << "parseline: input line exceed MAXLINE (" << MAX_LINE << ") chars " << line << "\n";
+
+ exit_error(IRSTLM_ERROR_DATA, ss_msg.str());
+ }
+
+ howmany = parseWords(line, words, Order + 3);
+
+ // a valid line has Order words plus the prob, optionally plus the bow
+ if (!(howmany == (Order+ 1) || howmany == (Order + 2))){
+ MY_ASSERT(howmany == (Order+ 1) || howmany == (Order + 2));
+ }
+
+ //read words
+ ng.size=0;
+ for (int i=1; i<=Order; i++)
+ ng.pushw(strcmp(words[i],"<unk>")?words[i]:ng.dict->OOV());
+
+ //read logprob/code and logbow/code
+ // NOTE(review): the sscanf() side effect lives inside MY_ASSERT();
+ // MY_ASSERT is a real function in this file so the call is always
+ // evaluated, but the pattern breaks if it ever becomes a no-op macro.
+ MY_ASSERT(sscanf(words[0],"%f",&prob));
+ if (howmany==(Order+2)){
+ MY_ASSERT(sscanf(words[Order+1],"%f",&bow));
+ }else{
+ bow=0.0; //this is log10prob=0 for implicit backoff
+ }
+ return 1;
+}
+
+// Print a diagnostic and terminate the process with exit code err.
+// When msg is non-empty it is printed verbatim; otherwise a canned
+// description of the IRSTLM_* error code is emitted. Note that passing
+// a custom msg suppresses the code description entirely.
+void exit_error(int err, const std::string &msg){
+ if (msg != "") {
+ VERBOSE(0,msg+"\n";);
+ }
+ else{
+ switch(err){
+ case IRSTLM_NO_ERROR:
+ VERBOSE(0,"Everything OK\n");
+ break;
+ case IRSTLM_ERROR_GENERIC:
+ VERBOSE(0,"Generic error\n");
+ break;
+ case IRSTLM_ERROR_IO:
+ VERBOSE(0,"Input/Output error\n");
+ break;
+ case IRSTLM_ERROR_MEMORY:
+ VERBOSE(0,"Allocation memory error\n");
+ break;
+ case IRSTLM_ERROR_DATA:
+ VERBOSE(0,"Data format error\n");
+ break;
+ case IRSTLM_ERROR_MODEL:
+ VERBOSE(0,"Model computation error\n");
+ break;
+ default:
+ VERBOSE(0,"Undefined error\n");
+ break;
+ }
+ }
+ exit(err);
+};
+
+/*
+#ifdef MY_ASSERT_FLAG
+#if MY_ASSERT_FLAG>0
+#undef MY_ASSERT(x)
+#define MY_ASSERT(x) do { assert(x); } while (0)
+#else
+#define MY_ASSERT(x) { UNUSED(x); }
+#endif
+#else
+#define MY_ASSERT(x) { UNUSED(x); }
+#endif
+*/
+
+/** assert macros and functions **/
+// Defining MY_ASSERT_FLAG as 0 disables assertion checking entirely.
+#ifdef MY_ASSERT_FLAG
+#if MY_ASSERT_FLAG==0
+#undef MY_ASSERT_FLAG
+#endif
+#endif
+
+// Function (not macro) wrapper around assert(): unlike the standard
+// macro, the argument expression is always evaluated, even in the
+// disabled build, where its value is simply discarded.
+#ifdef MY_ASSERT_FLAG
+void MY_ASSERT(bool x) { assert(x); }
+#else
+void MY_ASSERT(bool x) { UNUSED(x); }
+#endif
+
+
+/** trace macros and functions**/
+/** verbose macros and functions**/
+
+#ifdef TRACE_LEVEL
+//int tracelevel=TRACE_LEVEL;
+const int tracelevel=TRACE_LEVEL;
+#else
+//int tracelevel=0;
+const int tracelevel=0;
+#endif
+
+
+namespace irstlm {
+ // realloc() variant with BSD reallocf() semantics: on allocation
+ // failure the original buffer is freed (instead of leaked) and NULL
+ // is returned; on success the (possibly moved) buffer is returned.
+ void* reallocf(void *ptr, size_t size){
+ void *p=realloc(ptr,size);
+
+ if (p)
+ {
+ return p;
+ }
+ else
+ {
+ free(ptr);
+ return NULL;
+ }
+ }
+
+}
+
diff --git a/src/util.h b/src/util.h
new file mode 100644
index 0000000..4d85170
--- /dev/null
+++ b/src/util.h
@@ -0,0 +1,97 @@
+// $Id: util.h 363 2010-02-22 15:02:45Z mfederico $
+
+#ifndef IRSTLM_UTIL_H
+#define IRSTLM_UTIL_H
+
+
+#include <string>
+#include <iostream>
+#include <fstream>
+#include <assert.h>
+
+using namespace std;
+
+#define MAX(a,b) (((a)>(b))?(a):(b))
+#define MIN(a,b) (((a)<(b))?(a):(b))
+
+//random values between -1 and +1
+#define MY_RAND (((float)random()/RAND_MAX)* 2.0 - 1.0)
+
+#define UNUSED(x) { (void) x; }
+
+#define LMTMAXLEV 20
+#define MAX_LINE 100000
+
+//0.000001 = 10^(-6)
+//0.000000000001 = 10^(-12)
+//1.000001 = 1+10^(-6)
+//1.000000000001 = 1+10^(-12)
+//0.999999 = 1-10^(-6)
+//0.999999999999 = 1-10^(-12)
+#define LOWER_SINGLE_PRECISION_OF_0 -0.000001
+#define UPPER_SINGLE_PRECISION_OF_0 0.000001
+#define LOWER_DOUBLE_PRECISION_OF_0 -0.000000000001
+#define UPPER_DOUBLE_PRECISION_OF_0 0.000000000001
+#define UPPER_SINGLE_PRECISION_OF_1 1.000001
+#define LOWER_SINGLE_PRECISION_OF_1 0.999999
+#define UPPER_DOUBLE_PRECISION_OF_1 1.000000000001
+#define LOWER_DOUBLE_PRECISION_OF_1 0.999999999999
+
+#define IRSTLM_NO_ERROR 0
+#define IRSTLM_ERROR_GENERIC 1
+#define IRSTLM_ERROR_IO 2
+#define IRSTLM_ERROR_MEMORY 3
+#define IRSTLM_ERROR_DATA 4
+#define IRSTLM_ERROR_MODEL 5
+
+#define BUCKET 10000
+#define SSEED 50
+
+class ngram;
+class mfstream;
+
+std::string gettempfolder();
+std::string createtempName();
+void createtempfile(mfstream &fileStream, std::string &filePath, std::ios_base::openmode flags);
+
+void removefile(const std::string &filePath);
+
+void *MMap(int fd, int access, off_t offset, size_t len, off_t *gap);
+int Munmap(void *p,size_t len,int sync);
+
+
+// A couple of utilities to measure access time
+void ResetUserTime();
+void PrintUserTime(const std::string &message);
+double GetUserTime();
+
+void ShowProgress(long long current,long long total);
+
+int parseWords(char *, const char **, int);
+int parseline(istream& inp, int Order,ngram& ng,float& prob,float& bow);
+
+void exit_error(int err, const std::string &msg="");
+
+namespace irstlm{
+ void* reallocf(void *ptr, size_t size);
+}
+
+//extern int tracelevel;
+extern const int tracelevel;
+
+#define TRACE_ERR(str) { std::cerr << str; }
+#define VERBOSE(level,str) { if (tracelevel > level) { TRACE_ERR("DEBUG_LEVEL:" << level << "/" << tracelevel << " "); TRACE_ERR(str); } }
+#define IFVERBOSE(level) if (tracelevel > level)
+
+/*
+#define _DEBUG_LEVEL TRACE_LEVEL
+
+#define TRACE_ERR(str) { std::cerr << str; }
+#define VERBOSE(level,str) { if (_DEBUG_LEVEL > level) { TRACE_ERR("DEBUG_LEVEL:" <<_DEBUG_LEVEL << " "); TRACE_ERR(str); } }
+#define IFVERBOSE(level) if (_DEBUG_LEVEL > level)
+*/
+
+void MY_ASSERT(bool x);
+
+#endif
+
diff --git a/src/verify-caching.cpp b/src/verify-caching.cpp
new file mode 100644
index 0000000..e1d3e95
--- /dev/null
+++ b/src/verify-caching.cpp
@@ -0,0 +1,91 @@
+// $Id: compile-lm.cpp 3677 2010-10-13 09:06:51Z bertoldi $
+
+/******************************************************************************
+ IrstLM: IRST Language Model Toolkit, compile LM
+ Copyright (C) 2006 Marcello Federico, ITC-irst Trento, Italy
+
+ This library is free software; you can redistribute it and/or
+ modify it under the terms of the GNU Lesser General Public
+ License as published by the Free Software Foundation; either
+ version 2.1 of the License, or (at your option) any later version.
+
+ This library is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ Lesser General Public License for more details.
+
+ You should have received a copy of the GNU Lesser General Public
+ License along with this library; if not, write to the Free Software
+ Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+
+ ******************************************************************************/
+
+
+#include <iostream>
+#include <string>
+#include <stdlib.h>
+#include "cmd.h"
+#include "util.h"
+#include "mdiadapt.h"
+#include "lmContainer.h"
+
+/********************************/
+using namespace std;
+using namespace irstlm;
+
+// Print the tool's usage banner on stderr, followed by the declared
+// command-line parameters (via FullPrintParams). TypeFlag selects the
+// parameter-printing mode.
+void print_help(int TypeFlag=0){
+ std::cerr << std::endl << "verify_caching - verify whether caching is enabled or disabled" << std::endl;
+ std::cerr << std::endl << "USAGE:" << std::endl;
+ std::cerr << " verify_caching" << std::endl;
+ std::cerr << std::endl << "DESCRIPTION:" << std::endl;
+ std::cerr << std::endl << "OPTIONS:" << std::endl;
+
+ FullPrintParams(TypeFlag, 0, 1, stderr);
+}
+
+// Print msg when one is given, otherwise the full help screen.
+// (Collapsed the original's redundant "if (msg) ... if (!msg) ..."
+// pair into a single if/else; behavior is unchanged.)
+void usage(const char *msg = 0)
+{
+ if (msg) {
+ std::cerr << msg << std::endl;
+ } else {
+ print_help();
+ }
+}
+
+// Entry point: reports on stdout whether LM caching and train-caching
+// are enabled (as reported by lmContainer::is_cache_enabled and
+// mdiadaptlm::is_train_cache_enabled). Accepts only -h/--Help.
+int main(int argc, char **argv)
+{
+ bool help=false;
+
+ DeclareParams((char*)
+
+ "Help", CMDBOOLTYPE|CMDMSG, &help, "print this help",
+ "h", CMDBOOLTYPE|CMDMSG, &help, "print this help",
+
+ (char *)NULL
+ );
+
+ // NOTE(review): any extra argument triggers the help *before* parsing,
+ // so "-h" is handled the same way as an unknown argument here.
+ if (argc > 1){
+ usage();
+ exit_error(IRSTLM_NO_ERROR);
+ }
+
+ GetParams(&argc, &argv, (char*) NULL);
+
+ if (help){
+ usage();
+ exit_error(IRSTLM_NO_ERROR);
+ }
+
+ if (lmContainer::is_cache_enabled()){
+ std::cout << " caching is ENABLED" << std::endl;
+ }else{
+ std::cout << " caching is DISABLED" << std::endl;
+ }
+
+ if (mdiadaptlm::is_train_cache_enabled()){
+ std::cout << " train-caching is ENABLED" << std::endl;
+ }else{
+ std::cout << " train-caching is DISABLED" << std::endl;
+ }
+ // falling off main is well-defined in C++: implicit "return 0"
+}
--
Alioth's /usr/local/bin/git-commit-notice on /srv/git.debian.org/git/debian-science/packages/irstlm.git
More information about the debian-science-commits
mailing list