diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..07145db9f --- /dev/null +++ b/.gitignore @@ -0,0 +1,244 @@ +example_extff/libff_example.so* +Testing/ +*/Testing/ +training/Testing/ +utils/Testing/ +CTestTestfile.cmake +cmake_install.cmake +CMakeCache.txt +CMakeFiles +utils/dedup_corpus +klm/lm/builder/dump_counts +klm/util/cat_compressed +example_extff/libff_example.1.0.0.dylib +example_extff/libff_example.1.dylib +example_extff/libff_example.dylib +example_extff/ff_example.lo +example_extff/libff_example.la +mteval/meteor_jar.cc +training/utils/grammar_convert +*.a +*.trs +*.aux +*.bbl +*.blg +*.dvi +*.idx +*.log +*.o +*.pdf +*.ps +*.pyc +*.so +*.toc +*swp +*~ +.* +./cdec/ +Makefile +Makefile.in +aclocal.m4 +autom4te.cache/ +config.guess +config.h +config.h.in +config.h.in~ +config.log +config.status +config.sub +configure +decoder/Makefile +decoder/Makefile.in +decoder/bin/ +decoder/cdec +decoder/dict_test +decoder/sv_test +decoder/ff_test +decoder/grammar_test +decoder/hg_test +decoder/logval_test +decoder/parser_test +decoder/rule_lexer.cc +decoder/small_vector_test +decoder/t2s_test +decoder/trule_test +decoder/weights_test +depcomp +dist +dpmert/Makefile +dpmert/Makefile.in +dpmert/fast_score +dpmert/lo_test +dpmert/mr_dpmert_generate_mapper_input +dpmert/mr_dpmert_map +dpmert/mr_dpmert_reduce +dpmert/scorer_test +dpmert/sentclient +dpmert/sentserver +dpmert/union_forests +dtrain/dtrain +extools/build_lexical_translation +extools/extractor +extools/extractor_monolingual +extools/featurize_grammar +extools/filter_grammar +extools/filter_score_grammar +extools/mr_stripe_rule_reduce +extools/score_grammar +extools/sg_lexer.cc +extractor/*_test +extractor/compile +extractor/extract +extractor/run_extractor +extractor/sacompile +gi/clda/src/clda +gi/markov_al/ml +gi/pf/align-lexonly +gi/pf/align-lexonly-pyp +gi/pf/align-tl +gi/pf/bayes_lattice_score +gi/pf/brat +gi/pf/cbgi +gi/pf/condnaive +gi/pf/dpnaive +gi/pf/itg +gi/pf/learn_cfg +gi/pf/nuisance_test +gi/pf/pf_test +gi/pf/pfbrat +gi/pf/pfdist +gi/pf/pfnaive +gi/pf/pyp_lm +gi/posterior-regularisation/prjava/build/ +gi/posterior-regularisation/prjava/lib/*.jar +gi/posterior-regularisation/prjava/lib/prjava-20100713.jar +gi/posterior-regularisation/prjava/lib/prjava-20100715.jar +gi/posterior-regularisation/prjava/prjava.jar +gi/pyp-topics/src/contexts_lexer.cc +gi/pyp-topics/src/pyp-contexts-train +gi/pyp-topics/src/pyp-topics-train +install-sh +jam-files/bjam +jam-files/engine/bin.* +jam-files/engine/bootstrap/ +klm/lm/bin/ +klm/lm/builder/builder +klm/lm/builder/lmplz +klm/lm/build_binary +klm/lm/ngram_query +klm/lm/query +klm/util/bin/ +libtool +ltmain.sh +m4/libtool.m4 +m4/ltoptions.m4 +m4/ltsugar.m4 +m4/ltversion.m4 +m4/lt~obsolete.m4 +minrisk/minrisk_optimize +mira/kbest_mira +missing +mteval/bin/ +mteval/fast_score +mteval/mbr_kbest +mteval/scorer_test +phrasinator/gibbs_train_plm +phrasinator/gibbs_train_plm_notables +previous.sh +pro-train/mr_pro_map +pro-train/mr_pro_reduce +python/build +python/setup.py +rampion/rampion_cccp +rst_parser/mst_train +rst_parser/random_tree +rst_parser/rst_parse +rst_parser/rst_train +sa-extract/calignment.c +sa-extract/cdat.c +sa-extract/cfloatlist.c +sa-extract/cintlist.c +sa-extract/clex.c +sa-extract/cstrmap.c +sa-extract/csuf.c +sa-extract/cveb.c +sa-extract/lcp.c +sa-extract/precomputation.c +sa-extract/rule.c +sa-extract/rulefactory.c +sa-extract/sym.c +stamp-h1 +tests/system_tests/hmm/foo.src +training/Makefile +training/Makefile.in +training/atools 
+training/augment_grammar +training/cllh_filter_grammar +training/collapse_weights +training/grammar_convert +training/lbfgs_test +training/lbl_model +training/liblbfgs/bin/ +training/liblbfgs/ll_test +training/model1 +training/mpi_batch_optimize +training/mpi_adagrad_optimize +training/mpi_compute_cllh +training/mpi_em_optimize +training/mpi_extract_features +training/mpi_extract_reachable +training/mpi_flex_optimize +training/mpi_online_optimize +training/mr_em_adapted_reduce +training/mr_em_map_adapter +training/mr_optimize_reduce +training/mr_reduce_to_weights +training/optimize_test +training/plftools +training/test_ngram +training/const_reorder/argument_reorder_model_trainer +training/const_reorder/const_reorder_model_trainer +utils/atools +utils/bin/ +utils/crp_test +utils/dict_test +utils/logval_test +utils/m_test +utils/mfcr_test +utils/phmt +utils/reconstruct_weights +utils/small_vector_test +utils/sv_test +utils/ts +utils/weights_test +training/crf/mpi_adagrad_optimize +training/crf/mpi_batch_optimize +training/crf/mpi_baum_welch +training/crf/mpi_compute_cllh +training/crf/mpi_extract_features +training/crf/mpi_extract_reachable +training/crf/mpi_flex_optimize +training/crf/mpi_online_optimize +training/dpmert/lo_test +training/dpmert/mr_dpmert_generate_mapper_input +training/dpmert/mr_dpmert_map +training/dpmert/mr_dpmert_reduce +training/dpmert/sentclient +training/dpmert/sentserver +training/dtrain/dtrain +training/latent_svm/latent_svm +training/minrisk/minrisk_optimize +training/mira/ada_opt_sm +training/mira/kbest_mira +training/mira/kbest_cut_mira +training/pro/mr_pro_map +training/pro/mr_pro_reduce +training/rampion/rampion_cccp +training/utils/lbfgs_test +training/utils/optimize_test +training/utils/sentclient +training/utils/sentserver +utils/stringlib_test +word-aligner/binderiv +word-aligner/fast_align +test-driver diff --git a/.travis.yml b/.travis.yml new file mode 100644 index 000000000..1f0f2eeef --- /dev/null +++ b/.travis.yml @@ -0,0 +1,23 @@ +language: python +python: + - "2.7" +before_script: + - sudo apt-get install libboost-filesystem1.48-dev + - sudo apt-get install libboost-program-options1.48-dev + - sudo apt-get install libboost-serialization1.48-dev + - sudo apt-get install libboost-regex1.48-dev + - sudo apt-get install libboost-test1.48-dev + - sudo apt-get install libboost-system1.48-dev + - sudo apt-get install libboost-thread1.48-dev + - sudo apt-get install flex + - autoreconf -ifv + - ./configure +script: + - make + - cd python + - python setup.py install + - cd .. +after_script: + - make check + - ./tests/run-system-tests.pl + - nosetests python/tests diff --git a/BUILDING b/BUILDING new file mode 100644 index 000000000..055c6f821 --- /dev/null +++ b/BUILDING @@ -0,0 +1,32 @@ +To build cdec, you'll need: + + * boost headers & boost program_options (you may need to install a package + like libboost-dev) + + +Instructions for building +----------------------------------- + + 1) Create a build directory and generate Makefiles using CMake + + mkdir build + cd build + cmake .. + + If the cmake command completes successfully, you can proceed. If you have + libraries (such as Boost) installed in nonstandard locations, you may need + to run cmake with special options like -DBOOST_ROOT=/path/to/boost. + + 2) Build + + make -j 2 + + 3) Test + make test + ./tests/run-system-tests.pl + + Everything should pass. + + + 4) Enjoy! 
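+ For example, to point CMake at a Boost installed in a nonstandard prefix and
+ enable the METEOR metric (the paths below are illustrative placeholders, not
+ part of the repository), the configure-and-build step might look like:
+
+ mkdir build
+ cd build
+ cmake .. -DBOOST_ROOT=/path/to/boost -DMETEOR_JAR=/path/to/meteor-1.5.jar
+ make -j 2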
+ diff --git a/CMakeLists.txt b/CMakeLists.txt new file mode 100644 index 000000000..06d820613 --- /dev/null +++ b/CMakeLists.txt @@ -0,0 +1,68 @@ +cmake_minimum_required(VERSION 2.8) +project(cdec) + +add_definitions(-DKENLM_MAX_ORDER=6 -DHAVE_CONFIG_H) +set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake) +set(CMAKE_CXX_FLAGS "-Wall -std=c++11 -O3") +set(METEOR_JAR "" CACHE FILEPATH "Path to meteor.jar") + +enable_testing() +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) +include_directories(${CMAKE_CURRENT_BINARY_DIR}) + +# core packages +find_package(LibDL REQUIRED) +find_package(Boost COMPONENTS regex filesystem serialization program_options unit_test_framework system thread REQUIRED) +include_directories(${Boost_INCLUDE_DIR}) + +# eigen, used in some modeling extensions +find_package(Eigen3) +if(EIGEN3_FOUND) + include_directories(${EIGEN3_INCLUDE_DIR}) + set(HAVE_EIGEN 1) +endif(EIGEN3_FOUND) + +# compression packages (primarily used by KenLM) +find_package(ZLIB REQUIRED) +if(ZLIB_FOUND) + set(HAVE_ZLIB 1) +endif(ZLIB_FOUND) +find_package(BZip2) +if(BZIP2_FOUND) + set(HAVE_BZLIB 1) +endif(BZIP2_FOUND) +find_package(LibLZMA) +if(LIBLZMA_FOUND) + set(HAVE_XZLIB 1) +endif(LIBLZMA_FOUND) + +# for pycdec +find_package(PythonInterp 2.7 REQUIRED) + +# generate config.h +configure_file(${CMAKE_CURRENT_SOURCE_DIR}/config.h.cmake ${CMAKE_CURRENT_BINARY_DIR}/config.h) + +add_subdirectory(utils) +add_subdirectory(klm/util/double-conversion) +add_subdirectory(klm/util) +add_subdirectory(klm/util/stream) +add_subdirectory(klm/lm) +add_subdirectory(klm/lm/builder) +add_subdirectory(klm/search) +add_subdirectory(mteval) +add_subdirectory(decoder) +add_subdirectory(training) +add_subdirectory(word-aligner) +add_subdirectory(extractor) +add_subdirectory(example_extff) + +set(CPACK_PACKAGE_VERSION_MAJOR "2015") +set(CPACK_PACKAGE_VERSION_MINOR "04") +set(CPACK_PACKAGE_VERSION_PATCH "26") +set(CPACK_SOURCE_GENERATOR "TBZ2") +set(CPACK_SOURCE_PACKAGE_FILE_NAME + "${CMAKE_PROJECT_NAME}-${CPACK_PACKAGE_VERSION_MAJOR}.${CPACK_PACKAGE_VERSION_MINOR}.${CPACK_PACKAGE_VERSION_PATCH}") +set(CPACK_SOURCE_IGNORE_FILES + "/.git/;/.gitignore;/Testing/;/build/;/.bzr/;~$;/CMakeCache.txt;/CMakeFiles/;${CPACK_SOURCE_IGNORE_FILES}") +include(CPack) + diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 000000000..a390938bc --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,213 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. 
+ + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + +---------------------------------------------- + +L-BFGS CODE FROM COMPUTATIONAL CRYSTALLOGRAPHY TOOLBOX (CCTBX) + +This package includes source code (training/lbfgs.h) based on source +code distributed as part of the Compational Crystallography Toolbox +(CCTBX), which has separate copyright notices and license terms. 
Use of +this source code is subject to the terms and conditions of the license +contained in the file LICENSE.cctbx . + diff --git a/README.cmake b/README.cmake new file mode 100644 index 000000000..4653bee6c --- /dev/null +++ b/README.cmake @@ -0,0 +1,3 @@ + + cmake -G 'Unix Makefiles' -DMETEOR_JAR=/Users/cdyer/software/meteor-1.5/meteor-1.5.jar + diff --git a/README.md b/README.md new file mode 100644 index 000000000..25190b33e --- /dev/null +++ b/README.md @@ -0,0 +1,35 @@ +`cdec` is a research platform for machine translation and similar structured prediction problems. + +[![Build Status](https://travis-ci.org/redpony/cdec.svg?branch=master)](https://travis-ci.org/redpony/cdec) + +## System requirements + +- A Linux or Mac OS X system +- A C++ compiler implementing at least the [C++-11 standard](http://www.stroustrup.com/C++11FAQ.html) + - Some systems may have compilers that predate C++-11 support. + - You may need to build your own C++ compiler or upgrade your operating system's. +- [Boost C++ libraries (version 1.44 or later)](http://www.boost.org/) + - If you build your own boost, you _must install it_ using `bjam install` (to install it into a customized location use `--prefix=/path/to/target`). +- [GNU Flex](http://flex.sourceforge.net/) +- [cmake](http://www.cmake.org/) - (NEW) + +## Building the software + +Build instructions: + + mkdir build + cd build + cmake .. + make -j4 + make test + ./tests/run-system-tests.pl + +## Further information + +[For more information, refer to the `cdec` documentation](http://www.cdec-decoder.org) + +## Citation + +If you make use of cdec, please cite: + +C. Dyer, A. Lopez, J. Ganitkevitch, J. Weese, F. Ture, P. Blunsom, H. Setiawan, V. Eidelman, and P. Resnik. cdec: A Decoder, Alignment, and Learning Framework for Finite-State and Context-Free Translation Models. In *Proceedings of ACL*, July, 2010. [[bibtex](http://www.cdec-decoder.org/cdec.bibtex.txt)] [[pdf](http://www.aclweb.org/anthology/P/P10/P10-4002.pdf)] diff --git a/THREADS.txt b/THREADS.txt new file mode 100644 index 000000000..4dba2403a --- /dev/null +++ b/THREADS.txt @@ -0,0 +1,5 @@ +The cdec decoder is not, in general, thread safe. There are system components +that make use of multi-threading, but the decoder may not be used from multiple +threads. If you wish to decode in parallel, independent decoder processes +must be run. + diff --git a/cmake/FindEigen3.cmake b/cmake/FindEigen3.cmake new file mode 100644 index 000000000..9c546a05d --- /dev/null +++ b/cmake/FindEigen3.cmake @@ -0,0 +1,81 @@ +# - Try to find Eigen3 lib +# +# This module supports requiring a minimum version, e.g. you can do +# find_package(Eigen3 3.1.2) +# to require version 3.1.2 or newer of Eigen3. +# +# Once done this will define +# +# EIGEN3_FOUND - system has eigen lib with correct version +# EIGEN3_INCLUDE_DIR - the eigen include directory +# EIGEN3_VERSION - eigen version + +# Copyright (c) 2006, 2007 Montel Laurent, +# Copyright (c) 2008, 2009 Gael Guennebaud, +# Copyright (c) 2009 Benoit Jacob +# Redistribution and use is allowed according to the terms of the 2-clause BSD license. 
+ +if(NOT Eigen3_FIND_VERSION) + if(NOT Eigen3_FIND_VERSION_MAJOR) + set(Eigen3_FIND_VERSION_MAJOR 2) + endif(NOT Eigen3_FIND_VERSION_MAJOR) + if(NOT Eigen3_FIND_VERSION_MINOR) + set(Eigen3_FIND_VERSION_MINOR 91) + endif(NOT Eigen3_FIND_VERSION_MINOR) + if(NOT Eigen3_FIND_VERSION_PATCH) + set(Eigen3_FIND_VERSION_PATCH 0) + endif(NOT Eigen3_FIND_VERSION_PATCH) + + set(Eigen3_FIND_VERSION "${Eigen3_FIND_VERSION_MAJOR}.${Eigen3_FIND_VERSION_MINOR}.${Eigen3_FIND_VERSION_PATCH}") +endif(NOT Eigen3_FIND_VERSION) + +macro(_eigen3_check_version) + file(READ "${EIGEN3_INCLUDE_DIR}/Eigen/src/Core/util/Macros.h" _eigen3_version_header) + + string(REGEX MATCH "define[ \t]+EIGEN_WORLD_VERSION[ \t]+([0-9]+)" _eigen3_world_version_match "${_eigen3_version_header}") + set(EIGEN3_WORLD_VERSION "${CMAKE_MATCH_1}") + string(REGEX MATCH "define[ \t]+EIGEN_MAJOR_VERSION[ \t]+([0-9]+)" _eigen3_major_version_match "${_eigen3_version_header}") + set(EIGEN3_MAJOR_VERSION "${CMAKE_MATCH_1}") + string(REGEX MATCH "define[ \t]+EIGEN_MINOR_VERSION[ \t]+([0-9]+)" _eigen3_minor_version_match "${_eigen3_version_header}") + set(EIGEN3_MINOR_VERSION "${CMAKE_MATCH_1}") + + set(EIGEN3_VERSION ${EIGEN3_WORLD_VERSION}.${EIGEN3_MAJOR_VERSION}.${EIGEN3_MINOR_VERSION}) + if(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION}) + set(EIGEN3_VERSION_OK FALSE) + else(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION}) + set(EIGEN3_VERSION_OK TRUE) + endif(${EIGEN3_VERSION} VERSION_LESS ${Eigen3_FIND_VERSION}) + + if(NOT EIGEN3_VERSION_OK) + + message(STATUS "Eigen3 version ${EIGEN3_VERSION} found in ${EIGEN3_INCLUDE_DIR}, " + "but at least version ${Eigen3_FIND_VERSION} is required") + endif(NOT EIGEN3_VERSION_OK) +endmacro(_eigen3_check_version) + +if (EIGEN3_INCLUDE_DIR) + + # in cache already + _eigen3_check_version() + set(EIGEN3_FOUND ${EIGEN3_VERSION_OK}) + +else (EIGEN3_INCLUDE_DIR) + + find_path(EIGEN3_INCLUDE_DIR NAMES signature_of_eigen3_matrix_library + PATHS + ${CMAKE_INSTALL_PREFIX}/include + ${KDE4_INCLUDE_DIR} + PATH_SUFFIXES eigen3 eigen + ) + + if(EIGEN3_INCLUDE_DIR) + _eigen3_check_version() + endif(EIGEN3_INCLUDE_DIR) + + include(FindPackageHandleStandardArgs) + find_package_handle_standard_args(Eigen3 DEFAULT_MSG EIGEN3_INCLUDE_DIR EIGEN3_VERSION_OK) + + mark_as_advanced(EIGEN3_INCLUDE_DIR) + +endif(EIGEN3_INCLUDE_DIR) + diff --git a/cmake/FindGMock.cmake b/cmake/FindGMock.cmake new file mode 100644 index 000000000..2ad922129 --- /dev/null +++ b/cmake/FindGMock.cmake @@ -0,0 +1,130 @@ +# Locate the Google C++ Mocking Framework. +# (This file is almost an identical copy of the original FindGTest.cmake file, +# feel free to use it as it is or modify it for your own needs.) +# +# +# Defines the following variables: +# +# GMOCK_FOUND - Found the Google Testing framework +# GMOCK_INCLUDE_DIRS - Include directories +# +# Also defines the library variables below as normal +# variables. These contain debug/optimized keywords when +# a debugging library is found. 
+# +# GMOCK_BOTH_LIBRARIES - Both libgmock & libgmock-main +# GMOCK_LIBRARIES - libgmock +# GMOCK_MAIN_LIBRARIES - libgmock-main +# +# Accepts the following variables as input: +# +# GMOCK_ROOT - (as a CMake or environment variable) +# The root directory of the gmock install prefix +# +# GMOCK_MSVC_SEARCH - If compiling with MSVC, this variable can be set to +# "MD" or "MT" to enable searching a gmock build tree +# (defaults: "MD") +# +#----------------------- +# Example Usage: +# +# find_package(GMock REQUIRED) +# include_directories(${GMOCK_INCLUDE_DIRS}) +# +# add_executable(foo foo.cc) +# target_link_libraries(foo ${GMOCK_BOTH_LIBRARIES}) +# +#============================================================================= +# This file is released under the MIT licence: +# +# Copyright (c) 2011 Matej Svec +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to +# deal in the Software without restriction, including without limitation the +# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +# sell copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in +# all copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +# IN THE SOFTWARE. +#============================================================================= + + +function(_gmock_append_debugs _endvar _library) + if(${_library} AND ${_library}_DEBUG) + set(_output optimized ${${_library}} debug ${${_library}_DEBUG}) + else() + set(_output ${${_library}}) + endif() + set(${_endvar} ${_output} PARENT_SCOPE) +endfunction() + +function(_gmock_find_library _name) + find_library(${_name} + NAMES ${ARGN} + HINTS + $ENV{GMOCK_ROOT} + ${GMOCK_ROOT} + PATH_SUFFIXES ${_gmock_libpath_suffixes} + ) + mark_as_advanced(${_name}) +endfunction() + + +if(NOT DEFINED GMOCK_MSVC_SEARCH) + set(GMOCK_MSVC_SEARCH MD) +endif() + +set(_gmock_libpath_suffixes lib) +if(MSVC) + if(GMOCK_MSVC_SEARCH STREQUAL "MD") + list(APPEND _gmock_libpath_suffixes + msvc/gmock-md/Debug + msvc/gmock-md/Release) + elseif(GMOCK_MSVC_SEARCH STREQUAL "MT") + list(APPEND _gmock_libpath_suffixes + msvc/gmock/Debug + msvc/gmock/Release) + endif() +endif() + +find_path(GMOCK_INCLUDE_DIR gmock/gmock.h + HINTS + $ENV{GMOCK_ROOT}/include + ${GMOCK_ROOT}/include +) +mark_as_advanced(GMOCK_INCLUDE_DIR) + +if(MSVC AND GMOCK_MSVC_SEARCH STREQUAL "MD") + # The provided /MD project files for Google Mock add -md suffixes to the + # library names. 
+ _gmock_find_library(GMOCK_LIBRARY gmock-md gmock) + _gmock_find_library(GMOCK_LIBRARY_DEBUG gmock-mdd gmockd) + _gmock_find_library(GMOCK_MAIN_LIBRARY gmock_main-md gmock_main) + _gmock_find_library(GMOCK_MAIN_LIBRARY_DEBUG gmock_main-mdd gmock_maind) +else() + _gmock_find_library(GMOCK_LIBRARY gmock) + _gmock_find_library(GMOCK_LIBRARY_DEBUG gmockd) + _gmock_find_library(GMOCK_MAIN_LIBRARY gmock_main) + _gmock_find_library(GMOCK_MAIN_LIBRARY_DEBUG gmock_maind) +endif() + +include(FindPackageHandleStandardArgs) +FIND_PACKAGE_HANDLE_STANDARD_ARGS(GMock DEFAULT_MSG GMOCK_LIBRARY GMOCK_INCLUDE_DIR GMOCK_MAIN_LIBRARY) + +if(GMOCK_FOUND) + set(GMOCK_INCLUDE_DIRS ${GMOCK_INCLUDE_DIR}) + _gmock_append_debugs(GMOCK_LIBRARIES GMOCK_LIBRARY) + _gmock_append_debugs(GMOCK_MAIN_LIBRARIES GMOCK_MAIN_LIBRARY) + set(GMOCK_BOTH_LIBRARIES ${GMOCK_LIBRARIES} ${GMOCK_MAIN_LIBRARIES}) +endif() + diff --git a/cmake/FindLibDL.cmake b/cmake/FindLibDL.cmake new file mode 100644 index 000000000..1689e4c7f --- /dev/null +++ b/cmake/FindLibDL.cmake @@ -0,0 +1,30 @@ +# - Find libdl +# Find the native LIBDL includes and library +# +# LIBDL_INCLUDE_DIR - where to find dlfcn.h, etc. +# LIBDL_LIBRARIES - List of libraries when using libdl. +# LIBDL_FOUND - True if libdl found. + + +IF (LIBDL_INCLUDE_DIR) + # Already in cache, be silent + SET(LIBDL_FIND_QUIETLY TRUE) +ENDIF (LIBDL_INCLUDE_DIR) + +FIND_PATH(LIBDL_INCLUDE_DIR dlfcn.h) + +SET(LIBDL_NAMES dl libdl ltdl libltdl) +FIND_LIBRARY(LIBDL_LIBRARY NAMES ${LIBDL_NAMES} ) + +# handle the QUIETLY and REQUIRED arguments and set LIBDL_FOUND to TRUE if +# all listed variables are TRUE +INCLUDE(FindPackageHandleStandardArgs) +FIND_PACKAGE_HANDLE_STANDARD_ARGS(LibDL DEFAULT_MSG LIBDL_LIBRARY LIBDL_INCLUDE_DIR) + +IF(LIBDL_FOUND) + SET( LIBDL_LIBRARIES ${LIBDL_LIBRARY} ) +ELSE(LIBDL_FOUND) + SET( LIBDL_LIBRARIES ) +ENDIF(LIBDL_FOUND) + +MARK_AS_ADVANCED( LIBDL_LIBRARY LIBDL_INCLUDE_DIR ) diff --git a/cmake/FindRT.cmake b/cmake/FindRT.cmake new file mode 100644 index 000000000..55ae1a26c --- /dev/null +++ b/cmake/FindRT.cmake @@ -0,0 +1,55 @@ +# - Check for the presence of RT +# +# The following variables are set when RT is found: +# HAVE_RT = Set to true, if all components of RT +# have been found. 
+# RT_INCLUDES = Include path for the header files of RT +# RT_LIBRARIES = Link these to use RT + +## ----------------------------------------------------------------------------- +## Check for the header files + +find_path (RT_INCLUDES time.h + PATHS /usr/local/include /usr/include ${CMAKE_EXTRA_INCLUDES} + ) + +## ----------------------------------------------------------------------------- +## Check for the library + +find_library (RT_LIBRARIES rt + PATHS /usr/local/lib /usr/lib /lib ${CMAKE_EXTRA_LIBRARIES} + ) + +## ----------------------------------------------------------------------------- +## Actions taken when all components have been found + +if (RT_INCLUDES AND RT_LIBRARIES) + set (HAVE_RT TRUE) +else (RT_INCLUDES AND RT_LIBRARIES) + if (NOT RT_FIND_QUIETLY) + if (NOT RT_INCLUDES) + message (STATUS "Unable to find RT header files!") + endif (NOT RT_INCLUDES) + if (NOT RT_LIBRARIES) + message (STATUS "Unable to find RT library files!") + endif (NOT RT_LIBRARIES) + endif (NOT RT_FIND_QUIETLY) +endif (RT_INCLUDES AND RT_LIBRARIES) + +if (HAVE_RT) + if (NOT RT_FIND_QUIETLY) + message (STATUS "Found components for RT") + message (STATUS "RT_INCLUDES = ${RT_INCLUDES}") + message (STATUS "RT_LIBRARIES = ${RT_LIBRARIES}") + endif (NOT RT_FIND_QUIETLY) +else (HAVE_RT) + if (RT_FIND_REQUIRED) + message (FATAL_ERROR "Could not find RT!") + endif (RT_FIND_REQUIRED) +endif (HAVE_RT) + +mark_as_advanced ( + HAVE_RT + RT_LIBRARIES + RT_INCLUDES + ) diff --git a/compound-split/README.md b/compound-split/README.md new file mode 100644 index 000000000..b7491007a --- /dev/null +++ b/compound-split/README.md @@ -0,0 +1,51 @@ +Instructions for running the compound splitter, which is a reimplementation +and extension (more features, larger non-word list) of the model described in + + C. Dyer. (2009) Using a maximum entropy model to build segmentation + lattices for MT. In Proceedings of NAACL HLT 2009, + Boulder, Colorado, June 2009 + +If you use this software, please cite this paper. + + +GENERATING 1-BEST SEGMENTATIONS AND LATTICES +------------------------------------------------------------------------------ + +Here are some sample invocations: + + ./compound-split.pl --output 1best < infile.txt > out.1best.txt + Segments infile.txt and writes the 1-best segmentation to out.1best.txt. + + ./compound-split.pl --output plf < infile.txt > out.plf + + ./compound-split.pl --output plf --beam 3.5 < infile.txt > out.plf + This generates denser lattices than usual (the default beam threshold + is 2.2; higher numbers do less pruning) + + +MODEL TRAINING (only for the adventuresome) +------------------------------------------------------------------------------ + +I've included some training data for training a German language lattice +segmentation model, and if you want to explore, you can change the data. +If you're especially adventuresome, you can add features to cdec (the current +feature functions are found in ff_csplit.cc). The training/references are +in the file: + + dev.in-ref + +The format is the unsegmented form on the right and the reference lattice on +the left, separated by a triple pipe ( ||| ). Note that the segmentation +model inserts a # as the first word, so your segmentation references must +include this. + +To retrain the model (using MAP estimation of a conditional model), do the +following: + + cd de + ./TRAIN + +Note: the optimization objective is supposed to be non-convex, but I haven't +found much of an effect of where I initialize things.
But I haven't looked +very hard- this might be something to explore. + diff --git a/compound-split/cdec-de.ini b/compound-split/cdec-de.ini new file mode 100644 index 000000000..1573dd522 --- /dev/null +++ b/compound-split/cdec-de.ini @@ -0,0 +1,6 @@ +formalism=csplit +intersection_strategy=full +weights=de/weights.trained +#weights=de/weights.noun-only-1best-only +feature_function=CSplit_BasicFeatures de/large_dict.de.gz de/badlist.de.gz de/wordlist.de +feature_function=CSplit_ReverseCharLM de/charlm.rev.5gm.de.lm.gz diff --git a/compound-split/compound-split.pl b/compound-split/compound-split.pl new file mode 100755 index 000000000..93ac3b201 --- /dev/null +++ b/compound-split/compound-split.pl @@ -0,0 +1,177 @@ +#!/usr/bin/perl -w + +use strict; +my $script_dir; BEGIN { use Cwd qw/ abs_path /; use File::Basename; $script_dir = dirname(abs_path($0)); push @INC, $script_dir; } +use Getopt::Long; +use IPC::Open2; + +my $CDEC = "$script_dir/../decoder/cdec"; +my $LANG = 'de'; + +my $BEAM = 2.1; +my $OUTPUT = 'plf'; +my $HELP; +my $VERBOSE; +my $PRESERVE_CASE; + +GetOptions("decoder=s" => \$CDEC, + "language=s" => \$LANG, + "beam=f" => \$BEAM, + "output=s" => \$OUTPUT, + "verbose" => \$VERBOSE, + "preserve_case" => \$PRESERVE_CASE, + "help" => \$HELP + ) or usage(); + +usage() if $HELP; + +chdir $script_dir; + +if ($VERBOSE) { $VERBOSE = ""; } else { $VERBOSE = " 2> /dev/null"; } +$LANG = lc $LANG; +die "Can't find $CDEC\n" unless -f $CDEC; +die "Can't execute $CDEC\n" unless -x $CDEC; +die "Don't know about language: $LANG\n" unless -d "./$LANG"; +my $CONFIG="cdec-$LANG.ini"; +die "Can't find $CONFIG" unless -f $CONFIG; +die "--output must be '1best' or 'plf'\n" unless ($OUTPUT =~ /^(plf|1best)$/); +check_dependencies($CONFIG, $LANG); +print STDERR "(Run with --help for options)\n"; +print STDERR "LANGUAGE: $LANG\n"; +print STDERR " OUTPUT: $OUTPUT\n"; + +my $CMD = "$CDEC -c $CONFIG"; +my $IS_PLF; +if ($OUTPUT eq 'plf') { + $IS_PLF = 1; + $CMD .= " --csplit_preserve_full_word --csplit_output_plf --beam_prune $BEAM"; +} +$CMD .= $VERBOSE; + +print STDERR "Executing: $CMD\n"; + +open2(\*OUT, \*IN, $CMD) or die "Couldn't fork: $!"; +binmode(STDIN,":utf8"); +binmode(STDOUT,":utf8"); +binmode(IN,":utf8"); +binmode(OUT,":utf8"); + +while() { + chomp; + s/^\s+//; + s/\s+$//; + my @words = split /\s+/; + my @res = (); + my @todo = (); + my @casings = (); + for (my $i=0; $i < scalar @words; $i++) { + my $word = lc $words[$i]; + if (length($word)<6 || $word =~ /^[,\-0-9\.]+$/ || $word =~ /[@.\-\/:]/) { + push @casings, 0; + if ($IS_PLF) { + push @res, "(('" . escape($word) . 
"',0,1),),"; + } else { + if ($PRESERVE_CASE) { + push @res, $words[$i]; + } else { + push @res, $word; + } + } + } else { + push @casings, guess_casing($words[$i]); + push @res, undef; + push @todo, $word; + } + } + if (scalar @todo > 0) { + # print STDERR "TODO: @todo\n"; + my $tasks = join "\n", @todo; + print IN "$tasks\n"; + for (my $i = 0; $i < scalar @res; $i++) { + if (!defined $res[$i]) { + my $seg = ; + chomp $seg; + unless ($IS_PLF) { + $seg =~ s/^# //o; + } + if ($PRESERVE_CASE && $casings[$i]) { $seg = recase_words($seg); } + $res[$i] = $seg; + } + } + } + if ($IS_PLF) { + print '('; + print join '', @res; + print ")\n"; + } else { + print "@res\n"; + } +} + +close IN; +close OUT; + +sub recase_words { + my $word = shift; + $word =~ s/\b(\w)/\u$1/g; + return $word; +} + +sub escape { + $_ = shift; + s/\\/\\\\/g; + s/'/\\'/g; + return $_; +} + +sub guess_casing { + my $word = shift @_; + if (lc($word) eq $word) { return 0; } else { return 1; } +} + +sub usage { + print <){ + chomp; + my @x = split /\s+/; + for my $f (@x) { + push @files, $f if ($f =~ /\.gz$/); + } + } + close F; + my $c = 0; + for my $file (@files) { + $c++ if -f $file; + } + if ($c != scalar @files) { + print STDERR <) { + chomp; + s/[\–":“„!=+*.@«#%&,»\?\/{}\$\(\)\[\];\-0-9]+/ /g; + $_ = lc $_; + my @words = split /\s+/; + for my $w (@words) { + next if length($w) == 0; + $d{$w}++; + $z++; + } +} +my $lz = log($z); +for my $w (sort {$d{$b} <=> $d{$a}} keys %d) { + my $c = $lz-log($d{$w}); + print "$w $c\n"; +} + diff --git a/config.h.cmake b/config.h.cmake new file mode 100644 index 000000000..a37f63887 --- /dev/null +++ b/config.h.cmake @@ -0,0 +1,10 @@ +#ifndef CONFIG_H +#define CONFIG_H + +#cmakedefine METEOR_JAR "@METEOR_JAR@" +#cmakedefine HAVE_ZLIB @HAVE_ZLIB@ +#cmakedefine HAVE_BZLIB @HAVE_BZLIB@ +#cmakedefine HAVE_XZLIB @HAVE_XZLIB@ +#cmakedefine HAVE_EIGEN @HAVE_EIGEN@ + +#endif // CONFIG_H diff --git a/corpus/README.md b/corpus/README.md new file mode 100644 index 000000000..adc35b849 --- /dev/null +++ b/corpus/README.md @@ -0,0 +1,37 @@ +This directory contains a number of useful scripts that are helpful for preprocessing parallel and monolingual corpora. They are provided for convenience and may be very useful, but their functionality will often be supplainted by other, more specialized tools. + +Many of these scripts assume that the input is [UTF-8 encoded](http://en.wikipedia.org/wiki/UTF-8). + +## Paste parallel files together + +This script reads one line at a time from a set of files and concatenates them with a triple pipe separator (`|||`) in the output. This is useful for generating parallel corpora files for training or evaluation: + + ./paste-files.pl file.a file.b file.c [...] + +## Punctuation Normalization and Tokenization + +This script tokenizes text in any language (well, it does a good job in most languages, and in some it will completely go crazy): + + ./tokenize-anything.sh < input.txt > output.txt + +It also normalizes a lot of unicode symbols and even corrects some common encoding errors. It can be applied to monolingual and parallel corpora directly. + +## Text lowercasing + +This script also does what it says, provided your input is in UTF8: + + ./lowercase.pl < input.txt > output.txt + +## Length ratio filtering (for parallel corpora) + +This script computes statistics about sentence length ratios in a parallel corpus and removes sentences that are statistical outliers. 
This tends to remove extremely poorly aligned sentence pairs or sentence pairs that would otherwise be difficult to align: + + ./filter-length.pl input.src-trg > output.src-trg + +## Add infrequent self-transaltions to a parallel corpus + +This script identifies rare words (those that occur less than 2 times in the corpus) and which have the same orthographic form in both the source and target language. Several copies of these words are then inserted at the end of the corpus that is written, which improves alignment quality. + + ./add-self-translations.pl input.src-trg > output.src-trg + + diff --git a/corpus/add-self-translations.pl b/corpus/add-self-translations.pl new file mode 100755 index 000000000..d707ce29c --- /dev/null +++ b/corpus/add-self-translations.pl @@ -0,0 +1,29 @@ +#!/usr/bin/perl -w +use strict; + +# ADDS SELF-TRANSLATIONS OF POORLY ATTESTED WORDS TO THE PARALLEL DATA + +my %df; +my %def; +while(<>) { +# print; + chomp; + my ($sf, $se) = split / \|\|\| /; + die "Format error: $_\n" unless defined $sf && defined $se; + my @fs = split /\s+/, $sf; + my @es = split /\s+/, $se; + for my $f (@fs) { + $df{$f}++; + for my $e (@es) { + if ($f eq $e) { $def{$f}++; } + } + } +} + +for my $k (sort keys %def) { + next if $df{$k} > 4; + print "$k ||| $k\n"; + print "$k ||| $k\n"; + print "$k ||| $k\n"; +} + diff --git a/corpus/add-sos-eos.pl b/corpus/add-sos-eos.pl new file mode 100755 index 000000000..d7608c5ec --- /dev/null +++ b/corpus/add-sos-eos.pl @@ -0,0 +1,63 @@ +#!/usr/bin/perl -w +use strict; + +die "Usage: $0 corpus.fr[-en1-en2-...] [corpus.al out-corpus.al]\n" unless (scalar @ARGV == 1 || scalar @ARGV == 3); +my $filec = shift @ARGV; +my $filea = shift @ARGV; +my $ofilea = shift @ARGV; +open C, "<$filec" or die "Can't read $filec: $!"; +if ($filea) { + open A, "<$filea" or die "Can't read $filea: $!"; + open OA, ">$ofilea" or die "Can't write $ofilea: $!"; +} +binmode(C, ":utf8"); +binmode(STDOUT, ":utf8"); +print STDERR "Adding and markers to input...\n"; +print STDERR " Reading corpus: $filec\n"; +print STDERR " Writing corpus: STDOUT\n"; +print STDERR "Reading alignments: $filea\n" if $filea; +print STDERR "Writing alignments: $ofilea\n" if $filea; + +my $lines = 0; +while() { + $lines++; + die "ERROR. Input line $filec:$lines should not contain SGML markup" if /; + die "ERROR. Mismatched number of lines between $filec and $filea\n" unless $aa; + chomp $aa; + my ($ff, $ee) = @fields; + die "ERROR in $filec:$lines: expected 'source ||| target'" unless defined $ee; + my @fs = split /\s+/, $ff; + my @es = split /\s+/, $ee; + my @as = split /\s+/, $aa; + my @oas = (); + push @oas, '0-0'; + my $flen = scalar @fs; + my $elen = scalar @es; + for my $ap (@as) { + my ($a, $b) = split /-/, $ap; + die "ERROR. Bad format in: @as" unless defined $a && defined $b; + push @oas, ($a + 1) . '-' . ($b + 1); + } + push @oas, ($flen + 1) . '-' . ($elen + 1); + print OA "@oas\n"; + } + print "$o\n"; +} +if ($filea) { + close OA; + my $aa = ; + die "ERROR. Alignment input file $filea contains more lines than corpus file!\n" if $aa; +} +print STDERR "\nSUCCESS. 
Processed $lines lines.\n"; + diff --git a/corpus/conll2cdec.pl b/corpus/conll2cdec.pl new file mode 100755 index 000000000..ee4e07dbf --- /dev/null +++ b/corpus/conll2cdec.pl @@ -0,0 +1,42 @@ +#!/usr/bin/perl -w +use strict; + +die "Usage: $0 file.conll\n\n Converts a CoNLL formatted labeled sequence into cdec's format.\n\n" unless scalar @ARGV == 1; +open F, "<$ARGV[0]" or die "Can't read $ARGV[0]: $!\n"; + +my @xx; +my @yy; +my @os; +my $sec = undef; +my $i = 0; +while() { + chomp; + if (/^\s*$/) { + print "[$j]; + $sym =~ s/"/'/g; + push @oo, $sym; + } + my $zz = $j + 1; + print " feat$zz=\"@oo\""; + } + + print "> @xx ||| @yy \n"; + @xx = (); + @yy = (); + @os = (); + } else { + my ($x, @fs) = split /\s+/; + my $y = pop @fs; + if (!defined $sec) { $sec = scalar @fs; } + die unless $sec == scalar @fs; + push @xx, $x; + push @yy, $y; + push @os, \@fs; + } +} + diff --git a/corpus/cut-corpus.pl b/corpus/cut-corpus.pl new file mode 100755 index 000000000..0af3b23ca --- /dev/null +++ b/corpus/cut-corpus.pl @@ -0,0 +1,35 @@ +#!/usr/bin/perl -w +use strict; +die "Usage: $0 N\nSplits a corpus separated by ||| symbols and returns the Nth field\n" unless scalar @ARGV > 0; + +my $x = shift @ARGV; +my @ind = split /,/, $x; +my @o = (); +for my $ff (@ind) { + if ($ff =~ /^\d+$/) { + push @o, $ff - 1; + } elsif ($ff =~ /^(\d+)-(\d+)$/) { + my $a = $1; + my $b = $2; + die "$a-$b is a bad range in input: $x\n" unless $b > $a; + for (my $i=$a; $i <= $b; $i++) { + push @o, $i - 1; + } + } else { + die "Bad input: $x\n"; + } +} + +while(<>) { + chomp; + my @fields = split /\s*\|\|\|\s*/; + my @sf; + for my $i (@o) { + my $y = $fields[$i]; + if (!defined $y) { $y= ''; } + push @sf, $y; + } + print join(' ||| ', @sf) . "\n"; +} + + diff --git a/corpus/filter-length.pl b/corpus/filter-length.pl new file mode 100755 index 000000000..8b73a1c86 --- /dev/null +++ b/corpus/filter-length.pl @@ -0,0 +1,152 @@ +#!/usr/bin/perl -w +use strict; +use utf8; + +##### EDIT THESE SETTINGS #################################################### +my $AUTOMATIC_INCLUDE_IF_SHORTER_THAN = 7; # if both are shorter, include +my $MAX_ZSCORE = 1.8; # how far from the mean can the (log)ratio be? 
+############################################################################## + +die "Usage: $0 [-NNN] corpus.fr-en\n\n Filter sentence pairs containing sentences longer than NNN words (where NNN\n is 150 by default) or whose log length ratios are $MAX_ZSCORE stddevs away from the\n mean log ratio.\n\n" unless scalar @ARGV == 1 || scalar @ARGV == 2; +binmode(STDOUT,":utf8"); +binmode(STDERR,":utf8"); + +my $MAX_LENGTH = 150; # discard a sentence if it is longer than this +if (scalar @ARGV == 2) { + my $fp = shift @ARGV; + die "Expected -NNN for first parameter, but got $fp\n" unless $fp =~ /^-(\d+)$/; + $MAX_LENGTH=$1; +} + +my $corpus = shift @ARGV; + +die "Cannot read from STDIN\n" if $corpus eq '-'; +my $ff = "<$corpus"; +$ff = "gunzip -c $corpus|" if $ff =~ /\.gz$/; + +print STDERR "Max line length (monolingual): $MAX_LENGTH\n"; +print STDERR " Parallel corpus: $corpus\n"; + +open F,$ff or die "Can't read $corpus: $!"; +binmode(F,":utf8"); + +my $rat_max = log(9); +my $lrm = 0; +my $zerof = 0; +my $zeroe = 0; +my $bad_format = 0; +my $absbadrat = 0; +my $overlene = 0; +my $overlenf = 0; +my $lines = 0; +my @lograts = (); +while() { + $lines++; + if ($lines % 100000 == 0) { print STDERR " [$lines]\n"; } + elsif ($lines % 2500 == 0) { print STDERR "."; } + my ($sf, $se, @d) = split /\s*\|\|\|\s*/; + if (scalar @d != 0 or !defined $se) { + $bad_format++; + if ($bad_format > 100 && ($bad_format / $lines) > 0.02) { + die "$bad_format / $lines : Corpus appears to be incorretly formatted, example: $_"; + } + next; + } + my @fs = (); + my @es = (); + if (defined $sf && length($sf) > 0) { @fs = split /\s+/, $sf; } + if (defined $se && length($se) > 0) { @es = split /\s+/, $se; } + my $flen = scalar @fs; + my $elen = scalar @es; + if ($flen == 0) { + $zerof++; + next; + } + if ($elen == 0) { + $zeroe++; + next; + } + if ($flen > $MAX_LENGTH) { + $overlenf++; + next; + } + if ($elen > $MAX_LENGTH) { + $overlene++; + next; + } + if ($elen >= $AUTOMATIC_INCLUDE_IF_SHORTER_THAN || + $flen >= $AUTOMATIC_INCLUDE_IF_SHORTER_THAN) { + my $lograt = log($flen) - log($elen); + if (abs($lograt) > $rat_max) { + $absbadrat++; + next; + } + $lrm += $lograt; + push @lograts, $lograt; + } +} +close F; + +print STDERR "\nComputing statistics...\n"; +my $lmean = $lrm / scalar @lograts; + +my $lsd = 0; +for my $lr (@lograts) { + $lsd += ($lr - $lmean)**2; +} +$lsd = sqrt($lsd / scalar @lograts); +@lograts = (); + +my $pass1_discard = $zerof + $zeroe + $absbadrat + $overlene + $overlenf + $bad_format; +my $discard_rate = int(10000 * $pass1_discard / $lines) / 100; +print STDERR " Total lines: $lines\n"; +print STDERR " Already discared: $pass1_discard\t(discard rate = $discard_rate%)\n"; +print STDERR " Mean F:E ratio: " . exp($lmean) . "\n"; +print STDERR " StdDev F:E ratio: " . exp($lsd) . 
"\n"; +print STDERR "Writing...\n"; +open F,$ff or die "Can't reread $corpus: $!"; +binmode(F,":utf8"); +my $to = 0; +my $zviol = 0; +my $worstz = -1; +my $worst = "\n"; +$lines = 0; +while() { + $lines++; + if ($lines % 100000 == 0) { print STDERR " [$lines]\n"; } + elsif ($lines % 2500 == 0) { print STDERR "."; } + my ($sf, $se, @d) = split / \|\|\| /; + if (!defined $se) { next; } + my @fs = split /\s+/, $sf; + my @es = split /\s+/, $se; + my $flen = scalar @fs; + my $elen = scalar @es; + next if ($flen == 0); + next if ($elen == 0); + next if ($flen > $MAX_LENGTH); + next if ($elen > $MAX_LENGTH); + if ($elen >= $AUTOMATIC_INCLUDE_IF_SHORTER_THAN || + $flen >= $AUTOMATIC_INCLUDE_IF_SHORTER_THAN) { + my $lograt = log($flen) - log($elen); + if (abs($lograt) > $rat_max) { + $absbadrat++; + next; + } + my $zscore = abs($lograt - $lmean) / $lsd; + if ($elen > $AUTOMATIC_INCLUDE_IF_SHORTER_THAN && + $flen > $AUTOMATIC_INCLUDE_IF_SHORTER_THAN && $zscore > $worstz) { $worstz = $zscore; $worst = $_; } + if ($zscore > $MAX_ZSCORE) { + $zviol++; + next; + } + print; + } else { + print; + } + $to++; +} +my $discard_rate2 = int(10000 * $zviol / ($lines - $pass1_discard)) / 100; +print STDERR "\n Lines printed: $to\n Ratio violations: $zviol\t(discard rate = $discard_rate2%)\n"; +print STDERR " Worst z-score: $worstz\n sentence: $worst"; +exit 0; + diff --git a/corpus/lowercase.pl b/corpus/lowercase.pl new file mode 100755 index 000000000..9fd91dac2 --- /dev/null +++ b/corpus/lowercase.pl @@ -0,0 +1,9 @@ +#!/usr/bin/perl -w +use strict; +binmode(STDIN,":utf8"); +binmode(STDOUT,":utf8"); +while() { + $_ = lc $_; + print; +} + diff --git a/corpus/moses-scfg-to-cdec.pl b/corpus/moses-scfg-to-cdec.pl new file mode 100755 index 000000000..9b8e36179 --- /dev/null +++ b/corpus/moses-scfg-to-cdec.pl @@ -0,0 +1,69 @@ +#!/usr/bin/perl -w +use strict; + +while(<>) { + my ($src, $trg, $feats, $al) = split / \|\|\| /; + # [X][NP] von [X][NP] [X] ||| [X][NP] 's [X][NP] [S] ||| 0.00110169 0.0073223 2.84566e-06 0.0027702 0.0121867 2.718 0.606531 ||| 0-0 1-1 2-2 ||| 635 245838 2 + + my @srcs = split /\s+/, $src; + my @trgs = split /\s+/, $trg; + my $lhs = pop @trgs; + $lhs =~ s/&apos;/'/g; + $lhs =~ s/'/'/g; + $lhs =~ s/,/COMMA/g; + my $ntc = 0; + my $sc = 0; + my @of = (); + my $x = pop @srcs; + my %d = (); # src index to nonterminal count + die "Expected [X]" unless $x eq '[X]'; + my %amap = (); + my @als = split / /, $al; + for my $st (@als) { + my ($s, $t) = split /-/, $st; + $amap{$t} = $s; + } + for my $f (@srcs) { + if ($f =~ /^\[X\]\[([^]]+)\]$/) { + $ntc++; + my $nt = $1; + $nt =~ s/&apos;/'/g; + $nt =~ s/'/'/g; + $nt =~ s/,/COMMA/g; + push @of, "[$nt]"; + $d{$sc} = $ntc; + } elsif ($f =~ /^\[[^]]+\]$/) { + die "Unexpected $f"; + } else { + push @of, $f; + } + $sc++; + } + my @oe = (); + my $ind = 0; + for my $e (@trgs) { + if ($e =~ /^\[X\]\[([^]]+)\]$/) { + my $imap = $d{$amap{$ind}}; + push @oe, "[$imap]"; + } else { + push @oe, $e; + } + $ind++; + } + my ($fe, $ef, $j, $lfe, $lef, $dummy, $of) = split / /, $feats; + next if $lef eq '0'; + next if $lfe eq '0'; + next if $ef eq '0'; + next if $fe eq '0'; + next if $j eq '0'; + next if $of eq '0'; + $ef = sprintf('%.6g', log($ef)); + $fe = sprintf('%.6g',log($fe)); + $j = sprintf('%.6g',log($j)); + $lef = sprintf('%.6g',log($lef)); + $lfe = sprintf('%.6g',log($lfe)); + $of = sprintf('%.6g',log($of)); + print "$lhs ||| @of ||| @oe ||| RuleCount=1 FgivenE=$fe EgivenF=$ef Joint=$j LexEgivenF=$lef LexFgivenE=$lfe Other=$of\n"; +} + +# [X][ADVP] angestiegen [X] 
||| rose [X][ADVP] [VP] ||| 0.0538131 0.0097508 0.00744224 0.0249653 0.000698602 2.718 0.606531 ||| 0-1 1-0 ||| 13 94 2 diff --git a/corpus/moses-xml.pl b/corpus/moses-xml.pl new file mode 100755 index 000000000..fca63aa8a --- /dev/null +++ b/corpus/moses-xml.pl @@ -0,0 +1,36 @@ +#!/usr/bin/perl -w + +use strict; +$|++; + +my $msg = "Usage: $0 (escape|unescape)\n\n Escapes XMl entities and other special characters for use with Moses.\n\n"; + +die $msg unless scalar @ARGV == 1; + +if ($ARGV[0] eq "escape") { + while () { + $_ =~ s/\&/\&/g; # escape escape + $_ =~ s/\|/\|/g; # factor separator + $_ =~ s/\/\>/g; # xml + $_ =~ s/\'/\'/g; # xml + $_ =~ s/\"/\"/g; # xml + $_ =~ s/\[/\[/g; # syntax non-terminal + $_ =~ s/\]/\]/g; # syntax non-terminal + print; + } +} elsif ($ARGV[0] eq "unescape") { + while () { + $_ =~ s/\|/\|/g; # factor separator + $_ =~ s/\</\/g; # xml + $_ =~ s/\'/\'/g; # xml + $_ =~ s/\"/\"/g; # xml + $_ =~ s/\[/\[/g; # syntax non-terminal + $_ =~ s/\]/\]/g; # syntax non-terminal + $_ =~ s/\&/\&/g; # escape escape + print; + } +} else { + die $msg; +} diff --git a/corpus/paste-files.pl b/corpus/paste-files.pl new file mode 100755 index 000000000..ef2cd9370 --- /dev/null +++ b/corpus/paste-files.pl @@ -0,0 +1,61 @@ +#!/usr/bin/perl -w +use strict; + +die "Usage: $0 file1.txt file2.txt [file3.txt ...]\n\n Performs a per-line concatenation of all files using the ||| seperator.\n\n" unless scalar @ARGV > 1; + +my @fhs = (); +for my $file (@ARGV) { + my $fh; + if ($file =~ /\.gz$/) { + open $fh, "gunzip -c $file|" or die "Can't fork gunzip -c $file: $!"; + } else { + open $fh, "<$file" or die "Can't read $file: $!"; + } + binmode($fh,":utf8"); + push @fhs, $fh; +} +binmode(STDOUT,":utf8"); +binmode(STDERR,":utf8"); + +my $bad = 0; +my $lc = 0; +my $done = 0; +my $fl = 0; +while(1) { + my @line; + $lc++; + if ($lc % 100000 == 0) { print STDERR " [$lc]\n"; $fl = 0; } + elsif ($lc % 2500 == 0) { print STDERR "."; $fl = 1; } + my $anum = 0; + for my $fh (@fhs) { + my $r = <$fh>; + if (!defined $r) { + die "Mismatched number of lines.\n" if scalar @line > 0; + $done = 1; + last; + } + $r =~ s/\r//g; + chomp $r; + if ($r =~ /\|\|\|/) { + $r = ''; + $bad++; + } + warn "$ARGV[$anum]:$lc contains a ||| symbol - please remove.\n" if $r =~ /\|\|\|/; + $r =~ s/\|\|\|/ /g; + $r =~ s/\s+/ /g; + $r =~ s/^ +//; + $r =~ s/ +$//; + $anum++; + push @line, $r; + } + last if $done; + print STDOUT join(' ||| ', @line) . 
"\n"; +} +print STDERR "\n" if $fl; +for (my $i = 1; $i < scalar @fhs; $i++) { + my $fh = $fhs[$i]; + my $r = <$fh>; + die "Mismatched number of lines.\n" if defined $r; +} +print STDERR "Number of lines containing ||| was: $bad\n" if $bad > 0; + diff --git a/corpus/sample-dev-sets.py b/corpus/sample-dev-sets.py new file mode 100755 index 000000000..3c969bbe7 --- /dev/null +++ b/corpus/sample-dev-sets.py @@ -0,0 +1,74 @@ +#!/usr/bin/env python + +import gzip +import os +import sys + +HELP = '''Process an input corpus by dividing it into pseudo-documents and uniformly +sampling train and dev sets (simulate uniform sampling at the document level +when document boundaries are unknown) + +usage: {} in_file out_prefix doc_size docs_per_dev_set dev_sets [-lc] +recommended: doc_size=20, docs_per_dev_set=100, dev_sets=2 (dev and test) +''' + +def gzopen(f): + return gzip.open(f, 'rb') if f.endswith('.gz') else open(f, 'r') + +def wc(f): + return sum(1 for _ in gzopen(f)) + +def main(argv): + + if len(argv[1:]) < 5: + sys.stderr.write(HELP.format(os.path.basename(argv[0]))) + sys.exit(2) + + # Args + in_file = os.path.abspath(argv[1]) + out_prefix = os.path.abspath(argv[2]) + doc_size = int(argv[3]) + docs_per_dev_set = int(argv[4]) + dev_sets = int(argv[5]) + lc = (len(argv[1:]) == 6 and argv[6] == '-lc') + + # Compute sizes + corpus_size = wc(in_file) + total_docs = corpus_size / doc_size + leftover = corpus_size % doc_size + train_docs = total_docs - (dev_sets * docs_per_dev_set) + train_batch_size = (train_docs / docs_per_dev_set) + + # Report + sys.stderr.write('Splitting {} lines ({} documents)\n'.format(corpus_size, total_docs + (1 if leftover else 0))) + sys.stderr.write('Train: {} ({})\n'.format((train_docs * doc_size) + leftover, train_docs + (1 if leftover else 0))) + sys.stderr.write('Dev: {} x {} ({})\n'.format(dev_sets, docs_per_dev_set * doc_size, docs_per_dev_set)) + + inp = gzopen(in_file) + train_out = open('{}.train'.format(out_prefix), 'w') + dev_out = [open('{}.dev.{}'.format(out_prefix, i + 1), 'w') for i in range(dev_sets)] + i = 0 + + # For each set of documents + for _ in range(docs_per_dev_set): + # Write several documents to train + for _ in range(train_batch_size): + for _ in range(doc_size): + i += 1 + train_out.write('{} ||| {}'.format(i, inp.readline()) if lc else inp.readline()) + # Write a document to each dev + for out in dev_out: + for _ in range(doc_size): + i += 1 + out.write('{} ||| {}'.format(i, inp.readline()) if lc else inp.readline()) + # Write leftover lines to train + for line in inp: + i += 1 + train_out.write('{} ||| {}'.format(i, line) if lc else line) + + train_out.close() + for out in dev_out: + out.close() + +if __name__ == '__main__': + main(sys.argv) diff --git a/corpus/support/README b/corpus/support/README new file mode 100644 index 000000000..fdbd523e7 --- /dev/null +++ b/corpus/support/README @@ -0,0 +1,2 @@ +Run ./tokenize.sh to tokenize text +Edit eng_token_patterns and eng_token_list to add rules for things not to segment diff --git a/corpus/support/fix-contract.pl b/corpus/support/fix-contract.pl new file mode 100755 index 000000000..49e889812 --- /dev/null +++ b/corpus/support/fix-contract.pl @@ -0,0 +1,12 @@ +#!/usr/bin/perl -w +$|++; + +use strict; +while(<>) { + #s/ (pre|anti|re|pro|inter|intra|multi|e|x|neo) - / $1- /ig; + #s/ - (year) - (old)/ -$1-$2/ig; + s/ ' (s|m|ll|re|d|ve) / '$1 /ig; + s/n ' t / n't /ig; + print; +} + diff --git a/corpus/support/fix-eos.pl b/corpus/support/fix-eos.pl new file mode 100755 index 
000000000..fe03727b2 --- /dev/null +++ b/corpus/support/fix-eos.pl @@ -0,0 +1,12 @@ +#!/usr/bin/perl -w +$|++; + +use strict; +use utf8; + +binmode(STDIN, ":utf8"); +binmode(STDOUT, ":utf8"); +while() { + s/(\p{Devanagari}{2}[A-Za-z0-9! ,.\@\p{Devanagari}]+?)\s+(\.)(\s*$|\s+\|\|\|)/$1 \x{0964}$3/s; + print; +} diff --git a/corpus/support/quote-norm.pl b/corpus/support/quote-norm.pl new file mode 100755 index 000000000..3eee06669 --- /dev/null +++ b/corpus/support/quote-norm.pl @@ -0,0 +1,193 @@ +#!/usr/bin/perl -w +$|++; +use strict; +use utf8; +binmode(STDIN,"utf8"); +binmode(STDOUT,"utf8"); +while() { + chomp; + $_ = " $_ "; + + # Delete control characters: + s/[\x{00}-\x{1f}]//g; + + # PTB --> normal + s/-LRB-/(/g; + s/-RRB-/)/g; + s/-LSB-/[/g; + s/-RSB-/]/g; + s/-LCB-/{/g; + s/-RCB-/}/g; + s/ gon na / gonna /g; + + # Regularize named HTML/XML escapes: + s/&\s*lt\s*;//gi; # HTML closing angle bracket + s/&\s*squot\s*;/'/gi; # HTML single quote + s/&\s*quot\s*;/"/gi; # HTML double quote + s/&\s*nbsp\s*;/ /gi; # HTML non-breaking space + s/'/\'/g; # HTML apostrophe + s/&\s*amp\s*;/&/gi; # HTML ampersand (last) + + # Regularize known HTML numeric codes: + s/&\s*#\s*160\s*;/ /gi; # no-break space + s/&\s*#45\s*;\s*&\s*#45\s*;/--/g; # hyphen-minus hyphen-minus + s/&\s*#45\s*;/--/g; # hyphen-minus + + # Convert arbitrary hex or decimal HTML entities to actual characters: + s/&\#x([0-9A-Fa-f]+);/pack("U", hex($1))/ge; + s/&\#([0-9]+);/pack("U", $1)/ge; + + # Regularlize spaces: + s/\x{ad}//g; # soft hyphen + s/\x{200C}//g; # zero-width non-joiner + s/\x{a0}/ /g; # non-breaking space + s/\x{2009}/ /g; # thin space + s/\x{2028}/ /g; # "line separator" + s/\x{2029}/ /g; # "paragraph separator" + s/\x{202a}/ /g; # "left-to-right embedding" + s/\x{202b}/ /g; # "right-to-left embedding" + s/\x{202c}/ /g; # "pop directional formatting" + s/\x{202d}/ /g; # "left-to-right override" + s/\x{202e}/ /g; # "right-to-left override" + s/\x{85}/ /g; # "next line" + s/\x{fffd}/ /g; # "replacement character" + s/\x{feff}/ /g; # byte-order mark + s/\x{fdd3}/ /g; # "unicode non-character" + + # Convert other Windows 1252 characters to UTF-8 + s/\x{80}/\x{20ac}/g; # euro sign + s/\x{95}/\x{2022}/g; # bullet + s/\x{99}/\x{2122}/g; # trademark sign + + # Currency and measure conversions: + s/ (\d\d): (\d\d)/ $1:$2/g; + s/[\x{20a0}]\x{20ac}]/ EUR /g; + s/[\x{00A3}]/ GBP /g; + s/(\W)([A-Z]+\$?)(\d*\.\d+|\d+)/$1$2 $3/g; + s/(\W)(euro?)(\d*\.\d+|\d+)/$1EUR $3/gi; + + # Ridiculous double conversions, UTF8 -> Windows 1252 -> UTF8: + s/�c/--/g; # long dash + s/\x{e2}\x{20ac}oe/\"/g; # opening double quote + s/\x{e2}\x{20ac}\x{9c}/\"/g; # opening double quote + s/\x{e2}\x{20ac}\x{9d}/\"/g; # closing double quote + s/\x{e2}\x{20ac}\x{2122}/\'/g; # apostrophe + s/\x{e2}\x{20ac}\x{201c}/ -- /g; # en dash? + s/\x{e2}\x{20ac}\x{201d}/ -- /g; # em dash? + s/â(\x{80}\x{99}|\x{80}\x{98})/'/g; # single quote? + s/â(\x{80}\x{9c}|\x{80}\x{9d})/"/g; # double quote? 
+ s/\x{c3}\x{9f}/\x{df}/g; # esset + s/\x{c3}\x{0178}/\x{df}/g; # esset + s/\x{c3}\x{a4}/\x{e4}/g; # a umlaut + s/\x{c3}\x{b6}/\x{f6}/g; # o umlaut + s/\x{c3}\x{bc}/\x{fc}/g; # u umlaut + s/\x{c3}\x{84}/\x{c4}/g; # A umlaut: create no C4s after this + s/\x{c3}\x{201e}/\x{c4}/g; # A umlaut: create no C4s after this + s/\x{c3}\x{96}/\x{d6}/g; # O umlaut + s/\x{c3}\x{2013}/\x{d6}/g; # O umlaut + s/\x{c3}\x{bc}/\x{dc}/g; # U umlaut + s/\x{80}/\x{20ac}/g; # euro sign + s/\x{95}/\x{2022}/g; # bullet + s/\x{99}/\x{2122}/g; # trademark sign + + # Regularize quotes: + s/ˇ/'/g; # caron + s/´/'/g; # acute accent + s/`/'/g; # grave accent + s/ˉ/'/g; # modified letter macron + s/ ,,/ "/g; # ghetto low-99 quote + s/``/"/g; # latex-style left quote + s/''/"/g; # latex-style right quote + s/\x{300c}/"/g; # left corner bracket + s/\x{300d}/"/g; # right corner bracket + s/\x{3003}/"/g; # ditto mark + s/\x{00a8}/"/g; # diaeresis + s/\x{92}/\'/g; # curly apostrophe + s/\x{2019}/\'/g; # curly apostrophe + s/\x{f03d}/\'/g; # curly apostrophe + s/\x{b4}/\'/g; # curly apostrophe + s/\x{2018}/\'/g; # curly single open quote + s/\x{201a}/\'/g; # low-9 quote + s/\x{93}/\"/g; # curly left quote + s/\x{201c}/\"/g; # curly left quote + s/\x{94}/\"/g; # curly right quote + s/\x{201d}/\"/g; # curly right quote + s/\x{2033}/\"/g; # curly right quote + s/\x{201e}/\"/g; # low-99 quote + s/\x{84}/\"/g; # low-99 quote (bad enc) + s/\x{201f}/\"/g; # high-rev-99 quote + s/\x{ab}/\"/g; # opening guillemet + s/\x{bb}/\"/g; # closing guillemet + s/\x{0301}/'/g; # combining acute accent + s/\x{203a}/\"/g; # angle quotation mark + s/\x{2039}/\"/g; # angle quotation mark + + # Space inverted punctuation: + s/¡/ ¡ /g; + s/¿/ ¿ /g; + + # Russian abbreviations: + s/ п. п. / п.п. /g; + s/ ст. л. / ст.л. /g; + s/ т. е. / т.е. /g; + s/ т. к. / т.к. /g; + s/ т. ч. / т.ч. /g; + s/ т. д. / т.д. /g; + s/ т. п. / т.п. /g; + s/ и. о. / и.о. /g; + s/ с. г. / с.г. /g; + s/ г. р. / г.р. /g; + s/ т. н. / т.н. /g; + s/ т. ч. / т.ч. /g; + s/ н. э. / н.э. /g; + + # Convert foreign numerals into Arabic numerals + tr/०-९/0-9/; # devangari + tr/౦-౯/0-9/; # telugu + tr/೦-೯/0-9/; # kannada + tr/೦-௯/0-9/; # tamil + tr/൦-൯/0-9/; # malayalam + + # Random punctuation: + tr/!-~/!-~/; + s/、/,/g; + # s/。/./g; + s/\x{85}/.../g; + s/…/.../g; + s/―/--/g; + s/–/--/g; + s/─/--/g; + s/—/--/g; + s/\x{97}/--/g; + s/•/ * /g; + s/\*/ * /g; + s/،/,/g; + s/؟/?/g; + s/ـ/ /g; + s/à ̄/i/g; + s/’/'/g; + s/â€"/"/g; + s/؛/;/g; + + # Regularize ligatures: + s/\x{9c}/oe/g; # "oe" ligature + s/\x{0153}/oe/g; # "oe" ligature + s/\x{8c}/Oe/g; # "OE" ligature + s/\x{0152}/Oe/g; # "OE" ligature + s/\x{fb00}/ff/g; # "ff" ligature + s/\x{fb01}/fi/g; # "fi" ligature + s/\x{fb02}/fl/g; # "fl" ligature + s/\x{fb03}/ffi/g; # "ffi" ligature + s/\x{fb04}/ffi/g; # "ffl" ligature + + s/β/ß/g; # WMT 2010 error + + # Strip extra spaces: + s/\s+/ /g; + s/^\s+//; + s/\s+$//; + + print "$_\n"; +} + diff --git a/corpus/support/token_list b/corpus/support/token_list new file mode 100644 index 000000000..00daa82b8 --- /dev/null +++ b/corpus/support/token_list @@ -0,0 +1,558 @@ +##################### hyphenated words added by Fei since 3/7/05 +##X-ray + +# Finnish +eaa. +ap. +arv. +ay. +eKr. +em. +engl. +esim. +fil. +lis. +fil. +maist. +fil.toht. +harv. +ilt. +jatk. +jKr. +jms. +jne. +joht. +klo +ko. +ks. +leht. +lv. +lyh. +mm. +mon. +nim. +nro. +ns. +nti. +os. +oy. +pj. +pnä. +puh. +pvm. +rva. +tms. +ts. +vars. +vrt. +ym. +yms. +yo. 
+>>>>>>> 8646b68e5b124f612fd65b51ea40624f65a2f3d6 + +# hindi abbreviation patterns +जन. +फर. +अग. +सित. +अक्टू. +अक्तू. +नव. +दिस. +डी.एल. +डी.टी.ओ. +डी.ए. +ए.एस.आई. +डी.टी.ओ. +एम.एस.आर.टी.सी. +बी.बी.एम.बी. +डी.एस.पी. +सी.आर.पी. +एस.डी.एम. +सी.डी.पी.ओ. +बी.डी.ओ. +एस.डी.ओ. +एम.पी.पी. +पी.एच.ई. +एस.एच.ओ. +ए.सी.पी. +यू.पी. +पी.एम. +आर.बी.डी. +वी.पी. +सी.ए.डी.पी. +ए. +बी. +सी. +डी. +ई. +एफ. +जी. +एच. +आई. +जे. +के. +एल. +एम. +एन. +ओ. +पी. +क़यू. +आर. +एस. +टी. +यू. +वी. +डबल्यू. +एक्स. +वाई. +ज़ेड. +ज़ी. + +##################### words made of punct only +:- +:-) +:-( ++= +-= +.= +*= +>= +<= +== +&& +|| +=> +-> +<- +:) +:( +;) + +#################### abbr added by Fei +oz. +fl. +tel. +1. +2. +3. +4. +5. +6. +7. +8. +9. +10. + +##################### abbreviation: words that contain period. +EE.UU. +ee.uu. +U.A.E +Ala. +Ph.D. +min. +max. +z.B. +d.h. +ggf. +ca. +bzw. +bzgl. +Eng. +i.e. +a.m. +am. +A.M. +Apr. +Ariz. +Ark. +Aug. +B.A.T. +B.A.T +Calif. +Co. +Conn. +Corp. +Cos. +D.C. +Dec. +Dept. +Dr. +Drs. +Feb. +Fla. +Fri. +Ga. +Gen. +gen. +GEN. +Gov. +Govt. +Ill. +Inc. +Jan. +Jr. +Jul. +Jun. +Kan. +L.A. +Lieut. +Lt. +Ltd. +Ma. +Mar. +Mass. +Md. +Mfg. +Mgr. +Mio. +Mrd. +Bio. +Minn. +Mo. +Mon. +Mr. +Mrs. +Ms. +Mt. +N.D. +Neb. +Nev. +No. +Nos. +Nov. +Oct. +Okla. +Op. +Ore. +Pa. +p.m +p.m. +I.B.C. +N.T.V +Pres. +Prof. +Prop. +Rd. +Rev. +R.J. +C.L +Rs. +Rte. +Sat. +W.T +Sen. +Sep. +Sept. +Sgt. +Sr. +SR. +St. +Ste. +Sun. +Tenn. +Tex. +Thu. +Tue. +Univ. +Va. +Vt. +Wed. +approx. +dept. +e.g. +E.G. +eg. +est. +etc. +ex. +ext. +ft. +hon. +hr. +hrs. +lab. +lb. +lbs. +mass. +misc. +no. +nos. +nt. +para. +paras. +pct. +prod. +rec. +ref. +rel. +rep. +sq. +st. +stg. +vol. +vs. +U.S. +J.S. +U.N. +u.n. +A. +B. +C. +D. +E. +F. +G. +H. +I. +J. +K. +L. +M. +N. +O. +P. +Q. +R. +S. +T. +U. +V. +W. +X. +Y. +Z. +А. +Б. +В. +Г. +Д. +Е. +Ё. +Ж. +З. +И. +Й. +К. +Л. +М. +Н. +О. +П. +Р. +С. +Т. +У. +Ф. +Х. +Ц. +Ч. +Ш. +Щ. +Ъ. +Ы. +Ь. +Э. +Ю. +Я. +л. +г. +обл. +гг. +в. +вв. +мин. +ч. +тыс. +млн. +млрд. +трлн. +кв. +куб. +руб. +коп. +долл. +Прим. +прим. +чел. +грн. +мин. +им. +проф. +акад. +ред. +авт. +корр. +соб. +спец. +см. +тж. +др. +пр. +букв. +# Two-letter abbreviations - can be written with space +п.п. +ст.л. +т.е. +т.к. +т.ч. +т.д. +т.п. +и.о. +с.г. +г.р. +т.н. +т.ч. +н.э. +# Swahili +A.D. +Afr. +A.G. +agh. +A.H. +A.M. +a.s. +B.A. +B.C. +Bi. +B.J. +B.K. +B.O.M. +Brig. +Bro. +bt. +bw. +Bw. +Cap. +C.C. +cCM. +C.I.A. +cit. +C.M.S. +Co. +Corp. +C.S.Sp. +C.W. +D.C. +Dk. +Dkt. +Dk.B. +Dr. +E.C. +e.g. +E.M. +E.n. +etc. +Feb. +F.F.U. +F.M. +Fr. +F.W. +I.C.O. +i.e. +I.L.C. +Inc. +Jan. +J.F. +Jr. +J.S. +J.V.W.A. +K.A.R. +K.A.U. +K.C.M.C. +K.k. +K.K. +k.m. +km. +K.m. +K.N.C.U. +K.O. +K.S. +Ksh. +kt. +kumb. +k.v. +kv. +L.G. +ltd. +Ltd. +M.A. +M.D. +mf. +Mh. +Mhe. +mil. +m.m. +M.m. +Mm. +M.M. +Mr. +Mrs. +M.S. +Mt. +Mw. +M.W. +Mwl. +na. +Na. +N.F. +N.J. +n.k. +nk. +n.k.w. +N.N. +Nov. +O.C.D. +op. +P.C. +Phd. +Ph.D. +P.J. +P.o. +P.O. +P.O.P. +P.P.F. +Prof. +P.s. +P.S. +Q.C. +Rd. +s.a.w. +S.A.W. +S.D. +Sept. +sh. +Sh. +SH. +shs. +Shs. +S.J. +S.L. +S.L.P. +S.s. +S.S. +St. +s.w. +s.w.T. +taz. +Taz. +T.C. +T.E.C. +T.L.P. +T.O.H.S. +Tsh. +T.V. +tz. +uk. +Uk. +U.M.C.A. +U.N. +U.S. +Ush. +U.W.T. +Viii. +Vol. +V.T.C. +W.H. +yamb. +Y.M.C.A. 
diff --git a/corpus/support/token_patterns b/corpus/support/token_patterns new file mode 100644 index 000000000..12558cddb --- /dev/null +++ b/corpus/support/token_patterns @@ -0,0 +1,7 @@ +/^(al|el|ul|e)\-[a-z]+$/ +/\.(fi|fr|es|co\.uk|de)$/ +/:[a-zä]+$/ +/^((а|А)(ль|ш)|уль)-\p{Cyrillic}+$/ +/^\p{Cyrillic}\.\p{Cyrillic}\.$/ +/^(\d|\d\d|\d\d\d)\.$/ + diff --git a/corpus/support/tokenizer.pl b/corpus/support/tokenizer.pl new file mode 100755 index 000000000..6cc9f4e1e --- /dev/null +++ b/corpus/support/tokenizer.pl @@ -0,0 +1,712 @@ +#!/usr/bin/env perl +$|++; + +my $script_dir; +BEGIN {$^W = 1; use Cwd qw/ abs_path /; use File::Basename; $script_dir = dirname(abs_path($0)); push @INC, $script_dir; } + +use strict; +use utf8; + +binmode STDIN, ":utf8"; +binmode STDOUT, ":utf8"; +binmode STDERR, ":utf8"; + +my $debug = 0; + + +############ options: +### for all options: +### 0 means no split on that symbol +### 1 means split on that symbol in all cases. +### 2 means do not split in condition 1. +### n means do not split in any of the conditions in the set {1, 2, ..., n-1}. + + +### prefix +## for "#": #90 +my $Split_On_SharpSign = 2; # 2: do not split on Num, e.g., "#90" + + +############## "infix" +my $Split_On_Tilde = 2; # 2: do not split on Num, e.g., "12~13". + +my $Split_On_Circ = 2; # 2: do not split on Num, e.g, "2^3" + +## for "&" +my $Split_On_AndSign = 2; # 2: do not split on short Name, e.g., "AT&T". + +## for hyphen: 1990-1992 +my $Split_On_Dash = 2; ## 2: do not split on number, e.g., "22-23". +my $Split_On_Underscore = 0; ## 0: do not split by underline + +## for ":": 5:4 +my $Split_On_Semicolon = 2; ## 2: don't split for num, e.g., "5:4" + +########### suffix +## for percent sign: 5% +my $Split_On_PercentSign = 1; ## 2: don't split num, e.g., 5% + +############# others +## for slash: 1/4 +my $Split_On_Slash = 2; ## 2: don't split on number, e.g., 1/4. +my $Split_On_BackSlash = 0; ## 0: do not split on "\", e.g., \t + +### for "$": US$120 +my $Split_On_DollarSign = 2; ### 2: US$120 => "US$ 120" + ### 1: US$120 => "US $ 120" +## for 's etc. +my $Split_NAposT = 1; ## n't +my $Split_AposS = 1; ## 's +my $Split_AposM = 1; ## 'm +my $Split_AposRE = 1; ## 're +my $Split_AposVE = 1; ## 've +my $Split_AposLL = 1; ## 'll +my $Split_AposD = 1; ## 'd + + +### some patterns +my $common_right_punc = '\x{0964}|\.|\,|\;|\!|:|\?|\"|\)|\]|\}|\>|\-'; + +#### step 1: read files + +my $workdir = $script_dir; +my $dict_file = "$workdir/token_list"; +my $word_patt_file = "$workdir/token_patterns"; + +open(my $dict_fp, "$dict_file") or die; +binmode($dict_fp, ":utf8"); + +# read in the list of words that should not be segmented, +## e.g.,"I.B.M.", co-operation. 
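+# Roughly: entries in token_list are lower-cased and matched against whole
+# tokens (so e.g. "Mr." passes through unsplit), while entries in
+# token_patterns are applied as regular expressions (so e.g. "12." is kept
+# whole because it matches /^(\d|\d\d|\d\d\d)\.$/).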
+my %dict_hash = (); +my $dict_entry = 0; +while(<$dict_fp>){ + chomp; + next if /^\s*$/; + s/^\s+//; + s/\s+$//; + tr/A-Z/a-z/; + $dict_hash{$_} = 1; + $dict_entry ++; +} + +open(my $patt_fp, "$word_patt_file") or die; +binmode($patt_fp, ":utf8"); +my @word_patts = (); +my $word_patt_num = 0; +while(<$patt_fp>){ + chomp; + next if /^\s*$/; + s/^\s+//; + s/\s+$//; + s/^\/(.+)\/$/$1/; # remove / / around the pattern + push(@word_patts, $_); + $word_patt_num ++; +} + + +###### step 2: process the input file +my $orig_token_total = 0; +my $deep_proc_token_total = 0; +my $new_token_total = 0; + +while(){ + chomp(); + s/\x{0970}/./g; # dev abbreviation character + if(/^(\[b\s+|\]b|\]f|\[f\s+)/ || (/^\[[bf]$/) || (/^\s*$/) || /^//; + $new_line =~ s/\s*<\s+(p|hl)\s+>/<$1>/; + $new_line =~ s/\s*<\s+\/\s+(p|hl|DOC)\s+>/<\/$1>/; + $new_line =~ s/<\s+\/\s+seg\s+>/<\/seg>/; + if ($new_line =~ /^\s*<\s+DOC\s+/) { + $new_line =~ s/\s+//g; + $new_line =~ s/DOC/DOC /; + $new_line =~ s/sys/ sys/; + } + if ($new_line =~ /^\s*<\s+(refset|srcset)\s+/) { + $new_line =~ s/\s+//g; + $new_line =~ s/(set|src|tgt|trg)/ $1/g; + } + + chomp $new_line; + print STDOUT "$new_line\n"; +} + +######################################################################## + +### tokenize a line. +sub proc_line { + my @params = @_; + my $param_num = scalar @params; + + if(($param_num < 1) || ($param_num > 3)){ + die "wrong number of params for proc_line: $param_num\n"; + } + + my $orig_line = $params[0]; + + $orig_line =~ s/^\s+//; + $orig_line =~ s/\s+$//; + + my @parts = split(/\s+/, $orig_line); + + if($param_num >= 2){ + my $orig_num_ptr = $params[1]; + $$orig_num_ptr = scalar @parts; + } + + my $new_line = ""; + + my $deep_proc_token = 0; + foreach my $part (@parts){ + my $flag = -1; + $new_line .= proc_token($part, \$flag) . " "; + $deep_proc_token += $flag; + } + + if($param_num == 3){ + my $deep_num_ptr = $params[2]; + $$deep_num_ptr = $deep_proc_token; + } + + return $new_line; +} + + + +## Tokenize a str that does not contain " ", return the new string +## The function handles the cases that the token needs not be segmented. 
+## for other cases, it calls deep_proc_token() +sub proc_token { + my @params = @_; + my $param_num = scalar @params; + if($param_num > 2){ + die "proc_token: wrong number of params: $param_num\n"; + } + + my $token = $params[0]; + + if(!defined($token)){ + return ""; + } + + my $deep_proc_flag; + + if($param_num == 2){ + $deep_proc_flag = $params[1]; + $$deep_proc_flag = 0; + } + + if($debug){ + print STDERR "pro_token:+$token+\n"; + } + + ### step 0: it has only one char + if(($token eq "") || ($token=~ /^.$/)){ + ## print STDERR "see +$token+\n"; + return $token; + } + + ## step 1: check the most common case + if($token =~ /^[a-z0-9\p{Cyrillic}\p{Greek}\p{Hebrew}\p{Han}\p{Arabic}\p{Devanagari}]+$/i){ + #if($token =~ /^[a-z0-9\p{Cyrillic}\p{Greek}\p{Hebrew}\p{Han}\p{Arabic}]+$/i){ + ### most common cases + return $token; + } + + ## step 2: check whether it is some NE entity + ### 1.2.4.6 + if($token =~ /^\d+(.\d+)+$/){ + return $token; + } + + if($token =~ /^\d+(.\d+)+(亿|百万|万|千)?$/){ + return $token; + } + + ## 1,234,345.34 + if($token =~ /^\d+(\.\d{3})*,\d+$/){ + ## number + return $token; + } + if($token =~ /^\d+(,\d{3})*\.\d+$/){ + ## number + return $token; + } + if($token =~ /^(@|#)[A-Za-z0-9_\p{Cyrillic}\p{Greek}\p{Hebrew}\p{Han}\p{Arabic}\p{Devanagari}]+.*$/){ + ## twitter hashtag or address + return proc_rightpunc($token); + } + + if($token =~ /^[a-z0-9\_\-]+\@[a-z\d\_\-]+(\.[a-z\d\_\-]+)*(.*)$/i){ + ### email address: xxx@yy.zz + return proc_rightpunc($token); + } + + if($token =~ /^(mailto|http|https|ftp|gopher|telnet|file)\:\/{0,2}([^\.]+)(\.(.+))*$/i){ + ### URL: http://xx.yy.zz + return proc_rightpunc($token); + } + + if($token =~ /^(www)(\.(.+))+$/i){ + ### www.yy.dd/land/ + return proc_rightpunc($token); + } + + if($token =~ /^(\w+\.)+(com|co|edu|org|gov|ly|cz|ru|eu)(\.[a-z]{2,3})?\:{0,2}(\/\S*)?$/i){ + ### URL: upenn.edu/~xx + return proc_rightpunc($token); + } + + if($token =~ /^\(\d{3}\)\d{3}(\-\d{4})($common_right_punc)*$/){ + ## only handle American phone numbers: e.g., (914)244-4567 + return proc_rightpunc($token); + } + + #my $t1 = '[\x{0600}-\x{06ff}a-z\d\_\.\-]'; + my $t1 = '[a-z\d\_\-\.\p{Cyrillic}\p{Greek}\p{Hebrew}\p{Han}\p{Arabic}\p{Devanagari}]'; + if($token =~ /^\/(($t1)+\/)+($t1)+\/?$/i){ + ### /nls/p/.... + return $token; + } + + if($token =~ /^\\(($t1)+\\)+($t1)+\\?$/i){ + ### \nls\p\.... + return $token; + } + + ## step 3: check the dictionary + my $token_lc = $token; + $token_lc =~ tr/A-Z/a-z/; + + if(defined($dict_hash{$token_lc})){ + return $token; + } + + ## step 4: check word_patterns + my $i=1; + foreach my $patt (@word_patts){ + if($token_lc =~ /$patt/){ + if($debug){ + print STDERR "+$token+ match pattern $i: +$patt+\n"; + } + return $token; + }else{ + $i++; + } + } + + ## step 5: call deep tokenization + if($param_num == 2){ + $$deep_proc_flag = 1; + } + return deep_proc_token($token); +} + + +### remove punct on the right side +### e.g., xxx@yy.zz, => xxx@yy.zz , +sub proc_rightpunc { + my ($token) = @_; + + $token =~ s/(($common_right_punc)+)$/ $1 /; + if($token =~ /\s/){ + return proc_line($token); + }else{ + return $token; + } +} + + + +####################################### +### return the new token: +### types of punct: +## T1 (2): the punct is always a token by itself no matter where it +### appears: " ; +## T2 (15): the punct that can be a part of words made of puncts only. +## ` ! @ + = [ ] ( ) { } | < > ? +## T3 (15): the punct can be part of a word that contains [a-z\d] +## T3: ~ ^ & : , # * % - _ \ / . 
$ ' +## infix: ~ (12~13), ^ (2^3), & (AT&T), : , +## prefix: # (#9), * (*3), +## suffix: % (10%), +## infix+prefix: - (-5), _ (_foo), +## more than one position: \ / . $ +## Appos: 'm n't ... + +## 1. separate by puncts in T1 +## 2. separate by puncts in T2 +## 3. deal with punct T3 one by one according to options +## 4. if the token remains unchanged after step 1-3, return the token + +## $line contains at least 2 chars, and no space. +sub deep_proc_token { + my ($line) = @_; + if($debug){ + print STDERR "deep_proc_token: +$line+\n"; + } + + ##### step 0: if it mades up of all puncts, remove one punct at a time. + if($line !~ /[\p{Cyrillic}\p{Greek}\p{Hebrew}\p{Han}\p{Arabic}\p{Devanagari}a-zA-Z\d]/){ + if($line =~ /^(\!+|\@+|\++|\=+|\*+|\<+|\>+|\|+|\?+|\x{0964}+|\.+|\-+|\_+|\&+)$/){ + ## ++ @@@@ !!! .... + return $line; + } + + if($line =~ /^(.)(.+)$/){ + my $t1 = $1; + my $t2 = $2; + return $t1 . " " . proc_token($t2); + }else{ + ### one char only + print STDERR "deep_proc_token: this should not happen: +$line+\n"; + return $line; + } + } + + ##### step 1: separate by punct T2 on the boundary + my $t2 = '\`|\!|\@|\+|\=|\[|\]|\<|\>|\||\(|\)|\{|\}|\?|\"|;|●|○'; + if($line =~ s/^(($t2)+)/$1 /){ + $line =~ s/"/“/; + return proc_line($line); + } + + if($line =~ s/(($t2)+)$/ $1/){ + $line =~ s/"/”/; + return proc_line($line); + } + + ## step 2: separate by punct T2 in any position + if($line =~ s/(($t2)+)/ $1 /g){ + $line =~ s/"/”/g; # probably before punctuation char + return proc_line($line); + } + + ##### step 3: deal with special puncts in T3. + if($line =~ /^(\,+)(.+)$/){ + my $t1 = $1; + my $t2 = $2; + return proc_token($t1) . " " . proc_token($t2); + } + + if($line =~ /^(.*[^\,]+)(\,+)$/){ + ## 19.3,,, => 19.3 ,,, + my $t1 = $1; + my $t2 = $2; + return proc_token($t1) . " " . proc_token($t2); + } + + ## remove the ending periods that follow number etc. + if($line =~ /^(.*(\d|\~|\^|\&|\:|\,|\#|\*|\%|€|\-|\_|\/|\\|\$|\'))(\.+)$/){ + ## 12~13. => 12~13 . + my $t1 = $1; + my $t3 = $3; + return proc_token($t1) . " " . proc_token($t3); + } + + ### deal with "$" + if(($line =~ /\$/) && ($Split_On_DollarSign > 0)){ + my $suc = 0; + if($Split_On_DollarSign == 1){ + ## split on all occasation + $suc = ($line =~ s/(\$+)/ $1 /g); + }else{ + ## split only between $ and number + $suc = ($line =~ s/(\$+)(\d)/$1 $2/g); + } + + if($suc){ + return proc_line($line); + } + } + + ## deal with "#" + if(($line =~ /\#/) && ($Split_On_SharpSign > 0)){ + my $suc = 0; + if($Split_On_SharpSign >= 2){ + ### keep #50 as a token + $suc = ($line =~ s/(\#+)(\D)/ $1 $2/gi); + }else{ + $suc = ($line =~ s/(\#+)/ $1 /gi); + } + + if($suc){ + return proc_line($line); + } + } + + ## deal with ' + if($line =~ /\'/){ + my $suc = ($line =~ s/([^\'])([\']+)$/$1 $2/g); ## xxx'' => xxx '' + + ### deal with ': e.g., 's, 't, 'm, 'll, 're, 've, n't + + ## 'there => ' there '98 => the same + $suc += ($line =~ s/^(\'+)([a-z\p{Cyrillic}\p{Greek}\p{Hebrew}\p{Han}\p{Arabic}\p{Devanagari}]+)/ $1 $2/gi); + + ## note that \' and \. could interact: e.g., U.S.'s; 're. + if($Split_NAposT && ($line =~ /^(.*[a-z]+)(n\'t)([\.]*)$/i)){ + ## doesn't => does n't + my $t1 = $1; + my $t2 = $2; + my $t3 = $3; + return proc_token($t1) . " " . $t2 . " " . proc_token($t3); + } + + ## 's, 't, 'm, 'll, 're, 've: they've => they 've + ## 1950's => 1950 's Co.'s => Co. 's + if($Split_AposS && ($line =~ /^(.+)(\'s)(\W*)$/i)){ + my $t1 = $1; + my $t2 = $2; + my $t3 = $3; + return proc_token($t1) . " " . $t2 . " " . 
proc_token($t3); + } + + if($Split_AposM && ($line =~ /^(.*[a-z]+)(\'m)(\.*)$/i)){ + my $t1 = $1; + my $t2 = $2; + my $t3 = $3; + return proc_token($t1) . " " . $t2 . " " . proc_token($t3); + } + + + if($Split_AposRE && ($line =~ /^(.*[a-z]+)(\'re)(\.*)$/i)){ + my $t1 = $1; + my $t2 = $2; + my $t3 = $3; + return proc_token($t1) . " " . $t2 . " " . proc_token($t3); + } + + if($Split_AposVE && ($line =~ /^(.*[a-z]+)(\'ve)(\.*)$/i)){ + my $t1 = $1; + my $t2 = $2; + my $t3 = $3; + return proc_token($t1) . " " . $t2 . " " . proc_token($t3); + } + + if($Split_AposLL && ($line =~ /^(.*[a-z]+)(\'ll)(\.*)$/i)){ + my $t1 = $1; + my $t2 = $2; + my $t3 = $3; + return proc_token($t1) . " " . $t2 . " " . proc_token($t3); + } + + if($Split_AposD && ($line =~ /^(.*[a-z]+)(\'d)(\.*)$/i)){ + my $t1 = $1; + my $t2 = $2; + my $t3 = $3; + return proc_token($t1) . " " . $t2 . " " . proc_token($t3); + } + + if($suc){ + return proc_line($line); + } + } + + + ## deal with "~" + if(($line =~ /\~/) && ($Split_On_Tilde > 0)){ + my $suc = 0; + if($Split_On_Tilde >= 2){ + ## keep 12~13 as one token + $suc += ($line =~ s/(\D)(\~+)/$1 $2 /g); + $suc += ($line =~ s/(\~+)(\D)/ $1 $2/g); + $suc += ($line =~ s/^(\~+)(\d)/$1 $2/g); + $suc += ($line =~ s/(\d)(\~+)$/$1 $2/g); + }else{ + $suc += ($line =~ s/(\~+)/ $1 /g); + } + if($suc){ + return proc_line($line); + } + } + + ## deal with "^" + if(($line =~ /\^/) && ($Split_On_Circ > 0)){ + my $suc = 0; + if($Split_On_Circ >= 2){ + ## keep 12~13 as one token + $suc += ($line =~ s/(\D)(\^+)/$1 $2 /g); + $suc += ($line =~ s/(\^+)(\D)/ $1 $2/g); + }else{ + $suc = ($line =~ s/(\^+)/ $1 /g); + } + if($suc){ + return proc_line($line); + } + } + + ## deal with ":" + if(($line =~ /\:/) && ($Split_On_Semicolon > 0)){ + ## 2: => 2 : + my $suc = ($line =~ s/^(\:+)/$1 /); + $suc += ($line =~ s/(\:+)$/ $1/); + if($Split_On_Semicolon >= 2){ + ## keep 5:4 as one token + $suc += ($line =~ s/(\D)(\:+)/$1 $2 /g); + $suc += ($line =~ s/(\:+)(\D)/ $1 $2/g); + }else{ + $suc += ($line =~ s/(\:+)/ $1 /g); + } + + if($suc){ + return proc_line($line); + } + } + + ### deal with hyphen: 1992-1993. 
21st-24th + if(($line =~ /\-/) && ($Split_On_Dash > 0)){ + my $suc = ($line =~ s/(\-{2,})/ $1 /g); + if($Split_On_Dash >= 2){ + ## keep 1992-1993 as one token + $suc += ($line =~ s/(\D)(\-+)/$1 $2 /g); + $suc += ($line =~ s/(\-+)(\D)/ $1 $2/g); + }else{ + ### always split on "-" + $suc += ($line =~ s/([\-]+)/ $1 /g); + } + + if($suc){ + return proc_line($line); + } + } + + ## deal with "_" + if(($line =~ /\_/) && ($Split_On_Underscore > 0)){ + ### always split on "-" + if($line =~ s/([\_]+)/ $1 /g){ + return proc_line($line); + } + } + + + + ## deal with "%" + if(($line =~ /\%|€/) && ($Split_On_PercentSign > 0)){ + my $suc = 0; + if($Split_On_PercentSign >= 2){ + $suc += ($line =~ s/(\D)(\%+|€+)/$1 $2/g); + }else{ + $suc += ($line =~ s/(\%+|€+)/ $1 /g); + } + + if($suc){ + return proc_line($line); + } + } + + + ### deal with "/": 4/5 + if(($line =~ /\//) && ($Split_On_Slash > 0)){ + my $suc = 0; + if($Split_On_Slash >= 2){ + $suc += ($line =~ s/(\D)(\/+)/$1 $2 /g); + $suc += ($line =~ s/(\/+)(\D)/ $1 $2/g); + }else{ + $suc += ($line =~ s/(\/+)/ $1 /g); + } + + if($suc){ + return proc_line($line); + } + } + + + ### deal with comma: 123,456 + if($line =~ /\,/){ + my $suc = 0; + $suc += ($line =~ s/([^\d]),/$1 , /g); ## xxx, 1923 => xxx , 1923 + $suc += ($line =~ s/\,\s*([^\d])/ , $1/g); ## 1923, xxx => 1923 , xxx + + $suc += ($line =~ s/,([\d]{1,2}[^\d])/ , $1/g); ## 1,23 => 1 , 23 + $suc += ($line =~ s/,([\d]{4,}[^\d])/ , $1/g); ## 1,2345 => 1 , 2345 + + $suc += ($line =~ s/,([\d]{1,2})$/ , $1/g); ## 1,23 => 1 , 23 + $suc += ($line =~ s/,([\d]{4,})$/ , $1/g); ## 1,2345 => 1 , 2345 + + if($suc){ + return proc_line($line); + } + } + + + ## deal with "&" + if(($line =~ /\&/) && ($Split_On_AndSign > 0)){ + my $suc = 0; + if($Split_On_AndSign >= 2){ + $suc += ($line =~ s/([a-z]{3,})(\&+)/$1 $2 /gi); + $suc += ($line =~ s/(\&+)([a-z]{3,})/ $1 $2/gi); + }else{ + $suc += ($line =~ s/(\&+)/ $1 /g); + } + + if($suc){ + return proc_line($line); + } + } + + ## deal with period + if($line =~ /\./){ + if($line =~ /^(([\+|\-])*(\d+\,)*\d*\.\d+\%*)$/){ + ### numbers: 3.5 + return $line; + } + + if ($line =~ /^(([a-z]|ए|बी|सी|डी|ई|एफ|जी|एच|आई|जे|के|एल|एम|एन|ओ|पी|क़यू|आर|एस|टी|यू|वी|डबल्यू|एक्स|वाई|ज़ेड|ज़ी)(\.([a-z]|ए|बी|सी|डी|ई|एफ|जी|एच|आई|जे|के|एल|एम|एन|ओ|पी|क़यू|आर|एस|टी|यू|वी|डबल्यू|एक्स|वाई|ज़ेड|ज़ी))+)(\.?)(\.*)$/i){ + ## I.B.M. + my $t1 = $1 . $5; + my $t3 = $6; + return $t1 . " ". proc_token($t3); + } + + ## Feb.. => Feb. . + if($line =~ /^(.*[^\.])(\.)(\.*)$/){ + my $p1 = $1; + my $p2 = $2; + my $p3 = $3; + + my $p1_lc = $p1; + $p1_lc =~ tr/A-Z/a-z/; + + if(defined($dict_hash{$p1_lc . $p2})){ + ## Dec.. => Dec. . + return $p1 . $p2 . " " . proc_token($p3); + }elsif(defined($dict_hash{$p1_lc})){ + return $p1 . " " . proc_token($p2 . $p3); + }else{ + ## this. => this . + return proc_token($p1) . " " . proc_token($p2 . 
$p3); + } + } + + if($line =~ s/(\.+)(.+)/$1 $2/g){ + return proc_line($line); + } + } + + + ## no pattern applies + return $line; +} + diff --git a/corpus/support/utf8-normalize-batch.pl b/corpus/support/utf8-normalize-batch.pl new file mode 100755 index 000000000..e574f861a --- /dev/null +++ b/corpus/support/utf8-normalize-batch.pl @@ -0,0 +1,28 @@ +#!/usr/bin/env perl + +use IPC::Open2; + +$|++; + +if (scalar(@ARGV) != 1) { + print STDERR "usage: $0 \"CMD\"\n"; + exit(2); +} + +$CMD = $ARGV[0]; + +while () { + s/\r\n*/\n/g; + $PID = open2(*SOUT, *SIN, $CMD); + print SIN "$_\n"; + close(SIN); + $_ = ; + close(SOUT); + waitpid($PID, 0); + chomp; + s/[\x00-\x1F]+/ /g; + s/ +/ /g; + s/^ //; + s/ $//; + print "$_\n"; +} diff --git a/corpus/support/utf8-normalize.sh b/corpus/support/utf8-normalize.sh new file mode 100755 index 000000000..af9895ba0 --- /dev/null +++ b/corpus/support/utf8-normalize.sh @@ -0,0 +1,43 @@ +#!/bin/bash + +# this is the location on malbec, if you want to run on another machine +# ICU may be installed in /usr or /usr/local +ICU_DIR=/usr0/tools/icu +UCONV_BIN=$ICU_DIR/bin/uconv +UCONV_LIB=$ICU_DIR/lib + +if [ -e $UCONV_BIN ] && [ -d $UCONV_LIB ] +then + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$UCONV_LIB + if [ ! -x $UCONV_BIN ] + then + echo "$0: Cannot execute $UCONV_BIN! Please fix." 1>&2 + exit + fi + CMD="$UCONV_BIN -f utf8 -t utf8 -x Any-NFKC --callback skip" +else + if which uconv > /dev/null + then + CMD="uconv -f utf8 -t utf8 -x Any-NFKC --callback skip" + else + echo "$0: Cannot find ICU uconv (http://site.icu-project.org/) ... falling back to iconv. Quality may suffer." 1>&2 + CMD="iconv -f utf8 -t utf8 -c" + fi +fi + +if [[ $# == 1 && $1 == "--batchline" ]]; then + perl $(dirname $0)/utf8-normalize-batch.pl "$CMD" +else + perl -e '$|++; while(<>){s/\r\n*/\n/g; print;}' \ + |$CMD \ + |/usr/bin/perl -w -e ' + $|++; + while (<>) { + chomp; + s/[\x00-\x1F]+/ /g; + s/ +/ /g; + s/^ //; + s/ $//; + print "$_\n"; + }' +fi diff --git a/corpus/tokenize-anything.sh b/corpus/tokenize-anything.sh new file mode 100755 index 000000000..c580e88b1 --- /dev/null +++ b/corpus/tokenize-anything.sh @@ -0,0 +1,29 @@ +#!/usr/bin/env bash + +ROOTDIR=`dirname $0` +SUPPORT=$ROOTDIR/support + +if [[ $# == 1 && $1 == '-u' ]] ; then + NORMARGS="--batchline" + SEDFLAGS="-u" +else + if [[ $# != 0 ]] ; then + echo Usage: `basename $0` [-u] \< file.in \> file.out 1>&2 + echo 1>&2 + echo Tokenizes text in a reasonable way in most languages. 1>&2 + echo 1>&2 + exit 1 + fi + NORMARGS="" + SEDFLAGS="" +fi + +$SUPPORT/utf8-normalize.sh $NORMARGS | + $SUPPORT/quote-norm.pl | + $SUPPORT/tokenizer.pl | + $SUPPORT/fix-eos.pl | + sed $SEDFLAGS -e 's/ al - / al-/g' | + $SUPPORT/fix-contract.pl | + sed $SEDFLAGS -e 's/^ //' | sed $SEDFLAGS -e 's/ $//' | + perl -e '$|++; while(<>){s/(\d+)(\.+)$/$1 ./; s/(\d+)(\.+) \|\|\|/$1 . 
|||/; print;}' + diff --git a/corpus/tokenize-parallel.py b/corpus/tokenize-parallel.py new file mode 100755 index 000000000..6e4d0bd83 --- /dev/null +++ b/corpus/tokenize-parallel.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python + +import gzip +import math +import os +import shutil +import subprocess +import sys +import tempfile + +DEFAULT_JOBS = 8 +DEFAULT_TMP = '/tmp' + +TOKENIZER = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'tokenize-anything.sh') + +def gzopen(f): + return gzip.open(f) if f.endswith('.gz') else open(f) + +def wc(f): + return sum(1 for line in gzopen(f)) + +def main(argv): + + if len(argv[1:]) < 1: + sys.stderr.write('Parallelize text normalization with multiple instances of tokenize-anything.sh\n\n') + sys.stderr.write('Usage: {} in-file [jobs [temp-dir]] >out-file\n'.format(argv[0])) + sys.exit(2) + + in_file = argv[1] + jobs = int(argv[2]) if len(argv[1:]) > 1 else DEFAULT_JOBS + tmp = argv[3] if len(argv[1:]) > 2 else DEFAULT_TMP + + work = tempfile.mkdtemp(prefix='tok.', dir=tmp) + in_wc = wc(in_file) + # Don't start more jobs than we have lines + jobs = min(jobs, in_wc) + lines_per = int(math.ceil(float(in_wc)/jobs)) + + inp = gzopen(in_file) + procs = [] + files = [] + outs = [] + for i in range(jobs): + raw = os.path.join(work, 'in.{}'.format(i)) + tok = os.path.join(work, 'out.{}'.format(i)) + files.append(tok) + # Write raw batch + raw_out = open(raw, 'w') + for _ in range(lines_per): + line = inp.readline() + if not line: + break + raw_out.write(line) + raw_out.close() + # Start tokenizer + raw_in = open(raw) + tok_out = open(tok, 'w') + outs.append(tok_out) + p = subprocess.Popen(TOKENIZER, stdin=raw_in, stdout=tok_out) + procs.append(p) + + # Cat output of each tokenizer as it finishes + for (p, f, o) in zip(procs, files, outs): + p.wait() + o.close() + for line in open(f): + sys.stdout.write(line) + + # Cleanup + shutil.rmtree(work) + +if __name__ == '__main__': + main(sys.argv) diff --git a/corpus/untok.pl b/corpus/untok.pl new file mode 100755 index 000000000..723e78cbe --- /dev/null +++ b/corpus/untok.pl @@ -0,0 +1,63 @@ +#!/usr/bin/perl -w + +use IO::Handle; +STDOUT->autoflush(1); + +while (<>) { + $output = ""; + @tokens = split; + $lspace = 0; + $qflag = 0; + for ($i=0; $i<=$#tokens; $i++) { + $token = $tokens[$i]; + $prev = $next = ""; + $rspace = 1; + if ($i > 0) { + $prev = $tokens[$i-1]; + } + if ($i < $#tokens) { + $next = $tokens[$i+1]; + } + + # possessives join to the left + if ($token =~ /^(n't|'(s|m|re|ll|ve|d))$/) { + $lspace = 0; + } elsif ($token eq "'" && $prev =~ /s$/) { + $lspace = 0; + + # hyphen only when a hyphen, not a dash + } elsif ($token eq "-" && $prev =~ /[A-Za-z0-9]$/ && $next =~ /^[A-Za-z0-9]/) { + $lspace = $rspace = 0; + + # quote marks alternate + } elsif ($token eq '"') { + if ($qflag) { + $lspace = 0; + } else { + $rspace = 0; + } + $qflag = !$qflag; + + # period joins on both sides when a decimal point + } elsif ($token eq "." 
&& $prev =~ /\d$/ && $next =~ /\d$/) { + $lspace = $rspace = 0; + + # Left joiners + } elsif ($token =~ /^[.,:;?!%)\]]$/) { + $lspace = 0; + # Right joiners + } elsif ($token =~ /^[$(\[]$/) { + $rspace = 0; + # Joiners on both sides + } elsif ($token =~ /^[\/]$/) { + $lspace = $rspace = 0; + } + + if ($lspace) { + $output .= " "; + } + $output .= $token; + $lspace = $rspace; + } + print "$output\n"; +} diff --git a/corpus/utf8-normalize.sh b/corpus/utf8-normalize.sh new file mode 100755 index 000000000..7c0db611a --- /dev/null +++ b/corpus/utf8-normalize.sh @@ -0,0 +1,25 @@ +#!/bin/bash + +# This script uses ICU uconv (http://site.icu-project.org/), if it's available +# to normalize UTF8 text into a standard form. For information about this +# process, refer to http://en.wikipedia.org/wiki/Unicode_equivalence#Normalization +# Escape characters between 0x00-0x1F are removed + +if which uconv > /dev/null +then + CMD="uconv -f utf8 -t utf8 -x Any-NFKC --callback skip --remove-signature" +else + echo "Cannot find ICU uconv (http://site.icu-project.org/) ... falling back to iconv. Normalization NOT taking place." 1>&2 + CMD="iconv -f utf8 -t utf8 -c" +fi + +$CMD | /usr/bin/perl -w -e ' + while (<>) { + chomp; + s/[\x00-\x1F]+/ /g; + s/ +/ /g; + s/^ //; + s/ $//; + print "$_\n"; + }' + diff --git a/corpus/xml-tok.py b/corpus/xml-tok.py new file mode 100755 index 000000000..4357ced63 --- /dev/null +++ b/corpus/xml-tok.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python + +import os +import re +import subprocess +import sys + +# Tokenize XML files with tokenize-anything.sh +# in: The earnings on its 10-year bonds are 28.45%. +# out: The earnings on its 10 - year bonds are 28.45 % . + +def escape(s): + return s.replace('&', '&amp;').replace('>', '&gt;').replace('<', '&lt;').replace('"', '&quot;').replace('\'', '&apos;') + +def unescape(s): + return s.replace('&gt;', '>').replace('&lt;', '<').replace('&quot;', '"').replace('&apos;', '\'').replace('&amp;', '&') + +def main(): + tok = subprocess.Popen([os.path.join(os.path.dirname(__file__), 'tokenize-anything.sh'), '-u'], stdin=subprocess.PIPE, stdout=subprocess.PIPE) + while True: + line = sys.stdin.readline() + if not line: + break + line = line.strip() + pieces = [] + eol = len(line) + pos = 0 + while pos < eol: + next = line.find('<', pos) + if next == -1: + next = eol + tok.stdin.write('{}\n'.format(unescape(line[pos:next]))) + pieces.append(escape(tok.stdout.readline().strip())) + if next == eol: + break + pos = line.find('>', next + 1) + if pos == -1: + pos = eol + else: + pos += 1 + pieces.append(line[next:pos]) + sys.stdout.write('{}\n'.format(' '.join(pieces).strip())) + tok.stdin.close() + tok.wait() + +if __name__ == '__main__': + main() diff --git a/decoder/CMakeLists.txt b/decoder/CMakeLists.txt new file mode 100644 index 000000000..07d85b9a9 --- /dev/null +++ b/decoder/CMakeLists.txt @@ -0,0 +1,184 @@ +INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR}) +INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR}/../utils) +INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR}/../mteval) +INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR}/../klm) +INCLUDE_DIRECTORIES(${CMAKE_CURRENT_SOURCE_DIR}/..)
+ +PROJECT(decoder CXX) + +if (CMAKE_VERSION VERSION_LESS 2.8.9) # TODO remove once we increase the cmake requirement + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fPIC -DPIC") +else() + set(CMAKE_POSITION_INDEPENDENT_CODE ON) +endif() + +find_package(FLEX REQUIRED) + +FLEX_TARGET(RuleLexer rule_lexer.ll ${CMAKE_CURRENT_BINARY_DIR}/rule_lexer.cc) + +set(libcdec_SRCS + aligner.h + apply_models.h + bottom_up_parser.h + bottom_up_parser-rs.h + csplit.h + decoder.h + earley_composer.h + factored_lexicon_helper.h + ff.h + ff_basic.h + ff_bleu.h + ff_charset.h + ff_conll.h + ff_const_reorder_common.h + ff_const_reorder.h + ff_context.h + ff_csplit.h + ff_external.h + ff_factory.h + ff_klm.h + ff_lexical.h + ff_lm.h + ff_ngrams.h + ff_parse_match.h + ff_register.h + ff_rules.h + ff_ruleshape.h + ff_sample_fsa.h + ff_soft_syn.h + ff_soft_syntax.h + ff_soft_syntax_mindist.h + ff_source_path.h + ff_source_syntax.h + ff_source_syntax2.h + ff_spans.h + ff_tagger.h + ff_wordalign.h + ff_wordset.h + ffset.h + forest_writer.h + freqdict.h + grammar.h + hg.h + hg_intersect.h + hg_io.h + hg_remove_eps.h + hg_sampler.h + hg_test.h + hg_union.h + incremental.h + inside_outside.h + kbest.h + lattice.h + lexalign.h + lextrans.h + nt_span.h + oracle_bleu.h + phrasebased_translator.h + phrasetable_fst.h + program_options.h + rule_lexer.h + sentence_metadata.h + sentences.h + tagger.h + translator.h + trule.h + viterbi.h + aligner.cc + apply_models.cc + bottom_up_parser.cc + bottom_up_parser-rs.cc + cdec_ff.cc + csplit.cc + decoder.cc + earley_composer.cc + factored_lexicon_helper.cc + ff.cc + ff_basic.cc + ff_bleu.cc + ff_charset.cc + ff_conll.cc + ff_context.cc + ff_const_reorder.cc + ff_csplit.cc + ff_external.cc + ff_factory.cc + ff_klm.cc + ff_lm.cc + ff_ngrams.cc + ff_parse_match.cc + ff_rules.cc + ff_ruleshape.cc + ff_soft_syn.cc + ff_soft_syntax.cc + ff_soft_syntax_mindist.cc + ff_source_path.cc + ff_source_syntax.cc + ff_source_syntax2.cc + ff_spans.cc + ff_tagger.cc + ff_wordalign.cc + ff_wordset.cc + ffset.cc + forest_writer.cc + fst_translator.cc + tree2string_translator.cc + grammar.cc + hg.cc + hg_intersect.cc + hg_io.cc + hg_remove_eps.cc + hg_sampler.cc + hg_union.cc + incremental.cc + lattice.cc + lexalign.cc + lextrans.cc + node_state_hash.h + tree_fragment.cc + tree_fragment.h + maxtrans_blunsom.cc + phrasebased_translator.cc + phrasetable_fst.cc + rescore_translator.cc + ${FLEX_RuleLexer_OUTPUTS} + scfg_translator.cc + tagger.cc + translator.cc + trule.cc + viterbi.cc) + +add_library(libcdec STATIC ${libcdec_SRCS}) + +set(cdec_SRCS cdec.cc) +add_executable(cdec ${cdec_SRCS}) +target_link_libraries(cdec libcdec mteval utils ksearch klm klm_util klm_util_double ${Boost_LIBRARIES} ${ZLIB_LIBRARIES} ${BZIP2_LIBRARIES} ${LIBLZMA_LIBRARIES} ${LIBDL_LIBRARIES}) + +set(TEST_SRCS + grammar_test.cc + hg_test.cc + parser_test.cc + t2s_test.cc + trule_test.cc) + +foreach(testSrc ${TEST_SRCS}) + #Extract the filename without an extension (NAME_WE) + get_filename_component(testName ${testSrc} NAME_WE) + + #Add compile target + set_source_files_properties(${testSrc} PROPERTIES COMPILE_FLAGS "-DBOOST_TEST_DYN_LINK -DTEST_DATA=\\\"test_data/\\\"") + add_executable(${testName} ${testSrc}) + + #link to Boost libraries AND your targets and dependencies + target_link_libraries(${testName} libcdec mteval utils ksearch klm klm_util klm_util_double ${Boost_LIBRARIES} ${ZLIB_LIBRARIES} ${BZIP2_LIBRARIES} ${LIBLZMA_LIBRARIES}) + + #I like to move testing binaries into a testBin directory + 
set_target_properties(${testName} PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + + #Finally add it to test execution - + #Notice the WORKING_DIRECTORY and COMMAND + add_test(NAME ${testName} COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/${testName} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) +endforeach(testSrc) + diff --git a/decoder/aligner.cc b/decoder/aligner.cc new file mode 100644 index 000000000..fd6483702 --- /dev/null +++ b/decoder/aligner.cc @@ -0,0 +1,306 @@ +#include "aligner.h" + +#include +#include + +#include + +#include "array2d.h" +#include "hg.h" +#include "kbest.h" +#include "sentence_metadata.h" +#include "inside_outside.h" +#include "viterbi.h" +#include "alignment_io.h" + +using namespace std; + +// used with lexical models since they may not fully generate the +// source string +void SourceEdgeCoveragesUsingParseIndices(const Hypergraph& g, + vector >* src_cov) { + src_cov->clear(); + src_cov->resize(g.edges_.size()); + + for (int i = 0; i < g.edges_.size(); ++i) { + const Hypergraph::Edge& edge = g.edges_[i]; + set& cov = (*src_cov)[i]; + // no words + if (edge.rule_->EWords() == 0 || edge.rule_->FWords() == 0) + continue; + // aligned to NULL (crf ibm variant only) + if (edge.prev_i_ == -1 || edge.i_ == -1) { + cov.insert(-1); + continue; + } + assert(edge.j_ >= 0); + assert(edge.prev_j_ >= 0); + if (edge.Arity() == 0) { + for (int k = edge.prev_i_; k < edge.prev_j_; ++k) + cov.insert(k); + } else { + // note: this code, which handles mixed NT and terminal + // rules assumes that nodes uniquely define a src and trg + // span. + int k = edge.prev_i_; + int j = 0; + const vector& f = edge.rule_->e(); // rules are inverted + while (k < edge.prev_j_) { + if (f[j] > 0) { + cov.insert(k); + // cerr << "src: " << k << endl; + ++k; + ++j; + } else { + const Hypergraph::Node& tailnode = g.nodes_[edge.tail_nodes_[-f[j]]]; + assert(tailnode.in_edges_.size() > 0); + // any edge will do: + const Hypergraph::Edge& rep_edge = g.edges_[tailnode.in_edges_.front()]; + //cerr << "skip " << (rep_edge.prev_j_ - rep_edge.prev_i_) << endl; // src span + k += (rep_edge.prev_j_ - rep_edge.prev_i_); // src span + ++j; + } + } + } + } +} + +int SourceEdgeCoveragesUsingTree(const Hypergraph& g, + int node_id, + int span_start, + vector* spans, + vector >* src_cov) { + const Hypergraph::Node& node = g.nodes_[node_id]; + int k = -1; + for (int i = 0; i < node.in_edges_.size(); ++i) { + const int edge_id = node.in_edges_[i]; + const Hypergraph::Edge& edge = g.edges_[edge_id]; + set& cov = (*src_cov)[edge_id]; + const vector& f = edge.rule_->e(); // rules are inverted + int j = 0; + k = span_start; + while (j < f.size()) { + if (f[j] > 0) { + cov.insert(k); + ++k; + ++j; + } else { + const int tail_node_id = edge.tail_nodes_[-f[j]]; + int &right_edge = (*spans)[tail_node_id]; + if (right_edge < 0) + right_edge = SourceEdgeCoveragesUsingTree(g, tail_node_id, k, spans, src_cov); + k = right_edge; + ++j; + } + } + } + return k; +} + +void SourceEdgeCoveragesUsingTree(const Hypergraph& g, + vector >* src_cov) { + src_cov->clear(); + src_cov->resize(g.edges_.size()); + vector span_sizes(g.nodes_.size(), -1); + SourceEdgeCoveragesUsingTree(g, g.nodes_.size() - 1, 0, &span_sizes, src_cov); +} + +int TargetEdgeCoveragesUsingTree(const Hypergraph& g, + int node_id, + int span_start, + vector* spans, + vector >* trg_cov) { + const Hypergraph::Node& node = g.nodes_[node_id]; + int k = -1; + for (int i = 0; i < node.in_edges_.size(); ++i) { + const int edge_id = node.in_edges_[i]; + const 
Hypergraph::Edge& edge = g.edges_[edge_id]; + set& cov = (*trg_cov)[edge_id]; + int ntc = 0; + const vector& e = edge.rule_->f(); // rules are inverted + int j = 0; + k = span_start; + while (j < e.size()) { + if (e[j] > 0) { + cov.insert(k); + ++k; + ++j; + } else { + const int tail_node_id = edge.tail_nodes_[ntc]; + ++ntc; + int &right_edge = (*spans)[tail_node_id]; + if (right_edge < 0) + right_edge = TargetEdgeCoveragesUsingTree(g, tail_node_id, k, spans, trg_cov); + k = right_edge; + ++j; + } + } + // cerr << "node=" << node_id << ": k=" << k << endl; + } + return k; +} + +void TargetEdgeCoveragesUsingTree(const Hypergraph& g, + vector >* trg_cov) { + trg_cov->clear(); + trg_cov->resize(g.edges_.size()); + vector span_sizes(g.nodes_.size(), -1); + TargetEdgeCoveragesUsingTree(g, g.nodes_.size() - 1, 0, &span_sizes, trg_cov); +} + +struct TransitionEventWeightFunction { + typedef SparseVector Result; + inline SparseVector operator()(const Hypergraph::Edge& e) const { + SparseVector result; + result.set_value(e.id_, e.edge_prob_); + return result; + } +}; + +inline void WriteProbGrid(const Array2D& m, ostream* pos) { + ostream& os = *pos; + char b[1024]; + for (int i=0; i* edges) { + bool fix_up_src_spans = false; + if (k_best > 1 && edges) { + cerr << "ERROR: cannot request multiple best alignments and provide an edge set!\n"; + abort(); + } + if (map_instead_of_viterbi) { + if (k_best != 0) { + cerr << "WARNING: K-best alignment extraction not available for MAP, use --aligner_use_viterbi\n"; + } + k_best = 1; + } else { + if (k_best == 0) k_best = 1; + } + const Hypergraph* g = &in_g; + HypergraphP new_hg; + if (!IsSentence(src_lattice) || + !IsSentence(trg_lattice)) { + if (map_instead_of_viterbi) { + cerr << " Lattice alignment: using Viterbi instead of MAP alignment\n"; + } + map_instead_of_viterbi = false; + fix_up_src_spans = !IsSentence(src_lattice); + } + + KBest::KBestDerivations, ViterbiPathTraversal> kbest(in_g, k_best); + boost::scoped_ptr > kbest_edges; + + for (int best = 0; best < k_best; ++best) { + const KBest::KBestDerivations, ViterbiPathTraversal>::Derivation* d = NULL; + if (!map_instead_of_viterbi) { + d = kbest.LazyKthBest(in_g.nodes_.size() - 1, best); + if (!d) break; // there are fewer than k_best derivations! 
+ const vector& yield = d->yield; + kbest_edges.reset(new vector(in_g.edges_.size(), false)); + for (int i = 0; i < yield.size(); ++i) { + assert(yield[i]->id_ < kbest_edges->size()); + (*kbest_edges)[yield[i]->id_] = true; + } + } + if (!map_instead_of_viterbi || edges) { + if (kbest_edges) edges = kbest_edges.get(); + new_hg = in_g.CreateViterbiHypergraph(edges); + for (int i = 0; i < new_hg->edges_.size(); ++i) + new_hg->edges_[i].edge_prob_ = prob_t::One(); + g = new_hg.get(); + } + + vector edge_posteriors(g->edges_.size(), prob_t::Zero()); + vector trg_sent; + vector src_sent; + if (fix_up_src_spans) { + ViterbiESentence(*g, &src_sent); + } else { + src_sent.resize(src_lattice.size()); + for (int i = 0; i < src_sent.size(); ++i) + src_sent[i] = src_lattice[i][0].label; + } + + ViterbiFSentence(*g, &trg_sent); + + if (edges || !map_instead_of_viterbi) { + for (int i = 0; i < edge_posteriors.size(); ++i) + edge_posteriors[i] = prob_t::One(); + } else { + SparseVector posts; + const prob_t z = InsideOutside, TransitionEventWeightFunction>(*g, &posts); + for (int i = 0; i < edge_posteriors.size(); ++i) + edge_posteriors[i] = posts.value(i) / z; + } + vector > src_cov(g->edges_.size()); + vector > trg_cov(g->edges_.size()); + TargetEdgeCoveragesUsingTree(*g, &trg_cov); + + if (fix_up_src_spans) + SourceEdgeCoveragesUsingTree(*g, &src_cov); + else + SourceEdgeCoveragesUsingParseIndices(*g, &src_cov); + + // figure out the src and reference size; + int src_size = src_sent.size(); + int ref_size = trg_sent.size(); + Array2D align(src_size + 1, ref_size, prob_t::Zero()); + for (int c = 0; c < g->edges_.size(); ++c) { + const prob_t& p = edge_posteriors[c]; + const set& srcs = src_cov[c]; + const set& trgs = trg_cov[c]; + for (set::const_iterator si = srcs.begin(); + si != srcs.end(); ++si) { + for (set::const_iterator ti = trgs.begin(); + ti != trgs.end(); ++ti) { + align(*si + 1, *ti) += p; + } + } + } + new_hg.reset(); + //if (g != &in_g) { g.reset(); } + + prob_t threshold(0.9); + const bool use_soft_threshold = true; // TODO configure + + Array2D grid(src_size, ref_size, false); + for (int j = 0; j < ref_size; ++j) { + if (use_soft_threshold) { + threshold = prob_t::Zero(); + for (int i = 0; i <= src_size; ++i) + if (align(i, j) > threshold) threshold = align(i, j); + //threshold *= prob_t(0.99); + } + for (int i = 0; i < src_size; ++i) + grid(i, j) = align(i+1, j) >= threshold; + } + if (out == &cout && k_best < 2) { + // TODO need to do some sort of verbose flag + WriteProbGrid(align, &cerr); + cerr << grid << endl; + } + (*out) << TD::GetString(src_sent) << " ||| " << TD::GetString(trg_sent) << " ||| "; + AlignmentIO::SerializePharaohFormat(grid, out); + } +}; + diff --git a/decoder/aligner.h b/decoder/aligner.h new file mode 100644 index 000000000..d68ceefc7 --- /dev/null +++ b/decoder/aligner.h @@ -0,0 +1,26 @@ +#ifndef ALIGNER_H + +#include +#include +#include +#include "array2d.h" +#include "lattice.h" + +class Hypergraph; +class SentenceMetadata; + +struct AlignerTools { + + // assumption: g contains derivations of input/ref and + // ONLY input/ref. 
+ // if edges is non-NULL, the alignment corresponding to the edge rules will be written + static void WriteAlignment(const Lattice& src, + const Lattice& ref, + const Hypergraph& g, + std::ostream* out, + bool map_instead_of_viterbi = true, + int k_best = 0, + const std::vector* edges = NULL); +}; + +#endif diff --git a/decoder/apply_models.cc b/decoder/apply_models.cc new file mode 100644 index 000000000..18c83fd46 --- /dev/null +++ b/decoder/apply_models.cc @@ -0,0 +1,655 @@ +////TODO: keep model state in forest? + +//TODO: (for many nonterminals, or multi-rescoring pass) either global +//best-first, or group by (NT,span) - use prev forest outside as a (admissable, +//if models are a subset and weights are same) heuristic + +#include "apply_models.h" + +#include +#include +#ifndef HAVE_OLD_CPP +# include +# include +#else +# include +# include +namespace std { using std::tr1::unordered_map; using std::tr1::unordered_set; } +#endif + +#include + +#include "node_state_hash.h" +#include "verbose.h" +#include "hg.h" +#include "ff.h" +#include "ffset.h" + +#define NORMAL_CP 1 +#define FAST_CP 2 +#define FAST_CP_2 3 + +using namespace std; + +struct Candidate; +typedef SmallVectorInt JVector; +typedef vector CandidateHeap; +typedef vector CandidateList; + +// default vector size (* sizeof string is memory used) +static const size_t kRESERVE_NUM_NODES = 500000ul; + +// life cycle: candidates are created, placed on the heap +// and retrieved by their estimated cost, when they're +// retrieved, they're incorporated into the +LM hypergraph +// where they also know the head node index they are +// attached to. After they are added to the +LM hypergraph +// vit_prob_ and est_prob_ fields may be updated as better +// derivations are found (this happens since the successor's +// of derivation d may have a better score- they are +// explored lazily). However, the updates don't happen +// when a candidate is in the heap so maintaining the heap +// property is not an issue. 
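+// Rough sketch of the two scores (see InitializeCandidate below):
+//   vit_prob_ = out_edge_.edge_prob_ * product of the antecedents' vit_prob_
+//   est_prob_ = vit_prob_ * the models' heuristic estimate for the new edge
+// The heap is ordered by est_prob_; vit_prob_ is the non-heuristic score of the
+// best derivation currently known for the candidate's node.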
+struct Candidate { + int node_index_; // -1 until incorporated + // into the +LM forest + const Hypergraph::Edge* in_edge_; // in -LM forest + Hypergraph::Edge out_edge_; + FFState state_; + const JVector j_; + prob_t vit_prob_; // these are fixed until the cand + // is popped, then they may be updated + prob_t est_prob_; + + Candidate(const Hypergraph::Edge& e, + const JVector& j, + const Hypergraph& out_hg, + const vector& D, + const FFStates& node_states, + const SentenceMetadata& smeta, + const ModelSet& models, + bool is_goal) : + node_index_(-1), + in_edge_(&e), + j_(j) { + InitializeCandidate(out_hg, smeta, D, node_states, models, is_goal); + } + + // used to query uniqueness + Candidate(const Hypergraph::Edge& e, + const JVector& j) : in_edge_(&e), j_(j) {} + + bool IsIncorporatedIntoHypergraph() const { + return node_index_ >= 0; + } + + void InitializeCandidate(const Hypergraph& out_hg, + const SentenceMetadata& smeta, + const vector >& D, + const FFStates& node_states, + const ModelSet& models, + const bool is_goal) { + const Hypergraph::Edge& in_edge = *in_edge_; + out_edge_.rule_ = in_edge.rule_; + out_edge_.feature_values_ = in_edge.feature_values_; + out_edge_.i_ = in_edge.i_; + out_edge_.j_ = in_edge.j_; + out_edge_.prev_i_ = in_edge.prev_i_; + out_edge_.prev_j_ = in_edge.prev_j_; + Hypergraph::TailNodeVector& tail = out_edge_.tail_nodes_; + tail.resize(j_.size()); + prob_t p = prob_t::One(); + // cerr << "\nEstimating application of " << in_edge.rule_->AsString() << endl; + for (int i = 0; i < tail.size(); ++i) { + const Candidate& ant = *D[in_edge.tail_nodes_[i]][j_[i]]; + assert(ant.IsIncorporatedIntoHypergraph()); + tail[i] = ant.node_index_; + p *= ant.vit_prob_; + } + prob_t edge_estimate = prob_t::One(); + if (is_goal) { + assert(tail.size() == 1); + const FFState& ant_state = node_states[tail.front()]; + models.AddFinalFeatures(ant_state, &out_edge_, smeta); + } else { + models.AddFeaturesToEdge(smeta, out_hg, node_states, &out_edge_, &state_, &edge_estimate); + } + vit_prob_ = out_edge_.edge_prob_ * p; + est_prob_ = vit_prob_ * edge_estimate; + } +}; + +ostream& operator<<(ostream& os, const Candidate& cand) { + os << "CAND["; + if (!cand.IsIncorporatedIntoHypergraph()) { os << "PENDING "; } + else { os << "+LM_node=" << cand.node_index_; } + os << " edge=" << cand.in_edge_->id_; + os << " j=<"; + for (int i = 0; i < cand.j_.size(); ++i) + os << (i==0 ? "" : " ") << cand.j_[i]; + os << "> vit=" << log(cand.vit_prob_); + os << " est=" << log(cand.est_prob_); + return os << ']'; +} + +struct HeapCandCompare { + bool operator()(const Candidate* l, const Candidate* r) const { + return l->est_prob_ < r->est_prob_; + } +}; + +struct EstProbSorter { + bool operator()(const Candidate* l, const Candidate* r) const { + return l->est_prob_ > r->est_prob_; + } +}; + +// the same candidate can be added multiple times if +// j is multidimensional (if you're going NW in Manhattan, you +// can first go north, then west, or you can go west then north) +// this is a hash function on the relevant variables from +// Candidate to enforce this. 
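+// Example: for an in-edge with two tail nodes, j=<1,1> can be reached either
+// from j=<0,1> or from j=<1,0>; hashing on (in_edge_->id_, j_) lets the unique
+// candidate set detect and drop the second arrival.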
+struct CandidateUniquenessHash { + size_t operator()(const Candidate* c) const { + size_t x = 5381; + x = ((x << 5) + x) ^ c->in_edge_->id_; + for (int i = 0; i < c->j_.size(); ++i) + x = ((x << 5) + x) ^ c->j_[i]; + return x; + } +}; + +struct CandidateUniquenessEquals { + bool operator()(const Candidate* a, const Candidate* b) const { + return (a->in_edge_ == b->in_edge_) && (a->j_ == b->j_); + } +}; + +typedef unordered_set UniqueCandidateSet; +typedef unordered_map > State2Node; + +class CubePruningRescorer { + +public: + CubePruningRescorer(const ModelSet& m, + const SentenceMetadata& sm, + const Hypergraph& i, + int pop_limit, + Hypergraph* o, + int s = NORMAL_CP ) : + models(m), + smeta(sm), + in(i), + out(*o), + D(in.nodes_.size()), + pop_limit_(pop_limit), + strategy_(s){ + if (!SILENT) cerr << " Applying feature functions (cube pruning, pop_limit = " << pop_limit_ << ')' << endl; + node_states_.reserve(kRESERVE_NUM_NODES); + } + + void Apply() { + int num_nodes = in.nodes_.size(); + assert(num_nodes >= 2); + int goal_id = num_nodes - 1; + int pregoal = goal_id - 1; + assert(in.nodes_[pregoal].out_edges_.size() == 1); + if (!SILENT) cerr << " "; + int has = 0; + for (int i = 0; i < in.nodes_.size(); ++i) { + if (!SILENT) { + int needs = (50 * i / in.nodes_.size()); + while (has < needs) { cerr << '.'; ++has; } + } + if (strategy_==NORMAL_CP){ + KBest(i, i == goal_id); + } + if (strategy_==FAST_CP){ + KBestFast(i, i == goal_id); + } + if (strategy_==FAST_CP_2){ + KBestFast2(i, i == goal_id); + } + } + if (!SILENT) { + cerr << endl; + cerr << " Best path: " << log(D[goal_id].front()->vit_prob_) + << "\t" << log(D[goal_id].front()->est_prob_) << endl; + } + out.PruneUnreachable(D[goal_id].front()->node_index_); + FreeAll(); + } + + private: + void FreeAll() { + for (int i = 0; i < D.size(); ++i) { + CandidateList& D_i = D[i]; + for (int j = 0; j < D_i.size(); ++j) + delete D_i[j]; + } + D.clear(); + } + + void IncorporateIntoPlusLMForest(size_t head_node_hash, Candidate* item, State2Node* s2n, CandidateList* freelist) { + Hypergraph::Edge* new_edge = out.AddEdge(item->out_edge_); + new_edge->edge_prob_ = item->out_edge_.edge_prob_; + + Candidate** o_item_ptr = nullptr; + if (item->state_.size() && models.NeedsStateErasure()) { + // When erasure of certain state bytes is needed, we must make a copy of + // the state instead of doing the erasure in-place because future + // candidates may require the information in the bytes to be erased. 
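+ // (Sketch of the intent: a feature may mark some of its state bytes as
+ // ignorable for recombination, so two candidates whose states differ only in
+ // those bytes should map to the same +LM node; node_states_ below still
+ // stores the full, unerased state for use by later candidates.)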
+ FFState state(item->state_); + models.EraseIgnoredBytes(&state); + o_item_ptr = &(*s2n)[state]; + } else { + o_item_ptr = &(*s2n)[item->state_]; + } + Candidate*& o_item = *o_item_ptr; + + if (!o_item) o_item = item; + + int& node_id = o_item->node_index_; + if (node_id < 0) { + Hypergraph::Node* new_node = out.AddNode(in.nodes_[item->in_edge_->head_node_].cat_); + new_node->node_hash = cdec::HashNode(head_node_hash, item->state_); // ID is combination of existing state + residual state + node_states_.push_back(item->state_); + node_id = new_node->id_; + } +#if 0 + Hypergraph::Node* node = &out.nodes_[node_id]; + out.ConnectEdgeToHeadNode(new_edge, node); +#else + out.ConnectEdgeToHeadNode(new_edge, node_id); +#endif + // update candidate if we have a better derivation + // note: the difference between the vit score and the estimated + // score is the same for all items with a common residual DP + // state + if (item->vit_prob_ > o_item->vit_prob_) { + if (item->state_.size() && models.NeedsStateErasure()) { + // node_states_ should still point to the unerased state. + node_states_[o_item->node_index_] = item->state_; + // sanity check! + FFState item_state(item->state_), o_item_state(o_item->state_); + models.EraseIgnoredBytes(&item_state); + models.EraseIgnoredBytes(&o_item_state); + assert(item_state == o_item_state); + } else { + assert(o_item->state_ == item->state_); // sanity check! + } + + o_item->est_prob_ = item->est_prob_; + o_item->vit_prob_ = item->vit_prob_; + } + if (item != o_item) freelist->push_back(item); + } + + void KBest(const int vert_index, const bool is_goal) { + // cerr << "KBest(" << vert_index << ")\n"; + CandidateList& D_v = D[vert_index]; + assert(D_v.empty()); + const Hypergraph::Node& v = in.nodes_[vert_index]; + // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; + const vector& in_edges = v.in_edges_; + CandidateHeap cand; + CandidateList freelist; + cand.reserve(in_edges.size()); + UniqueCandidateSet unique_cands; + for (int i = 0; i < in_edges.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; + const JVector j(edge.tail_nodes_.size(), 0); + cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); + bool is_new = unique_cands.insert(cand.back()).second; + assert(is_new); // these should all be unique! 
+ } +// cerr << " making heap of " << cand.size() << " candidates\n"; + make_heap(cand.begin(), cand.end(), HeapCandCompare()); + State2Node state2node; // "buf" in Figure 2 + int pops = 0; + while(!cand.empty() && pops < pop_limit_) { + pop_heap(cand.begin(), cand.end(), HeapCandCompare()); + Candidate* item = cand.back(); + cand.pop_back(); + // cerr << "POPPED: " << *item << endl; + PushSucc(*item, is_goal, &cand, &unique_cands); + IncorporateIntoPlusLMForest(v.node_hash, item, &state2node, &freelist); + ++pops; + } + D_v.resize(state2node.size()); + int c = 0; + for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i) + D_v[c++] = i->second; + sort(D_v.begin(), D_v.end(), EstProbSorter()); + // cerr << " expanded to " << D_v.size() << " nodes\n"; + + for (int i = 0; i < cand.size(); ++i) + delete cand[i]; + // freelist is necessary since even after an item merged, it still stays in + // the unique set so it can't be deleted til now + for (int i = 0; i < freelist.size(); ++i) + delete freelist[i]; + } + + void KBestFast(const int vert_index, const bool is_goal) { + // cerr << "KBest(" << vert_index << ")\n"; + CandidateList& D_v = D[vert_index]; + assert(D_v.empty()); + const Hypergraph::Node& v = in.nodes_[vert_index]; + // cerr << " has " << v.in_edges_.size() << " in-coming edges\n"; + const vector& in_edges = v.in_edges_; + CandidateHeap cand; + CandidateList freelist; + cand.reserve(in_edges.size()); + //init with j<0,0> for all rules-edges that lead to node-(NT-span) + for (int i = 0; i < in_edges.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; + const JVector j(edge.tail_nodes_.size(), 0); + cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); + } + // cerr << " making heap of " << cand.size() << " candidates\n"; + make_heap(cand.begin(), cand.end(), HeapCandCompare()); + State2Node state2node; // "buf" in Figure 2 + int pops = 0; + while(!cand.empty() && pops < pop_limit_) { + pop_heap(cand.begin(), cand.end(), HeapCandCompare()); + Candidate* item = cand.back(); + cand.pop_back(); + // cerr << "POPPED: " << *item << endl; + + PushSuccFast(*item, is_goal, &cand); + IncorporateIntoPlusLMForest(v.node_hash, item, &state2node, &freelist); + ++pops; + } + D_v.resize(state2node.size()); + int c = 0; + for (auto& i : state2node) { + D_v[c++] = i.second; + // cerr << "MERGED: " << *i.second << endl; + } + //cerr <<"Node id: "<< vert_index<< endl; + //#ifdef MEASURE_CA + // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<& in_edges = v.in_edges_; + CandidateHeap cand; + CandidateList freelist; + cand.reserve(in_edges.size()); + UniqueCandidateSet unique_accepted; + //init with j<0,0> for all rules-edges that lead to node-(NT-span) + for (int i = 0; i < in_edges.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[in_edges[i]]; + const JVector j(edge.tail_nodes_.size(), 0); + cand.push_back(new Candidate(edge, j, out, D, node_states_, smeta, models, is_goal)); + } + // cerr << " making heap of " << cand.size() << " candidates\n"; + make_heap(cand.begin(), cand.end(), HeapCandCompare()); + State2Node state2node; // "buf" in Figure 2 + int pops = 0; + while(!cand.empty() && pops < pop_limit_) { + pop_heap(cand.begin(), cand.end(), HeapCandCompare()); + Candidate* item = cand.back(); + cand.pop_back(); + bool is_new = unique_accepted.insert(item).second; + assert(is_new); // these should all be unique! 
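+ // Algorithm 3 of Gesmundo & Henderson (2010): unique_accepted records every
+ // popped candidate so PushSuccFast2 enqueues a successor only once all of its
+ // predecessor candidates have been accepted (see HasAllAncestors below).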
+ // cerr << "POPPED: " << *item << endl; + + PushSuccFast2(*item, is_goal, &cand, &unique_accepted); + IncorporateIntoPlusLMForest(v.node_hash, item, &state2node, &freelist); + ++pops; + } + D_v.resize(state2node.size()); + int c = 0; + for (State2Node::iterator i = state2node.begin(); i != state2node.end(); ++i){ + D_v[c++] = i->second; + // cerr << "MERGED: " << *i->second << endl; + } + //cerr <<"Node id: "<< vert_index<< endl; + //#ifdef MEASURE_CA + // cerr << "countInProcess (pop/tot): node id: " << vert_index << " (" << count_in_process_pop << "/" << count_in_process_tot << ")"<tail_nodes_[i]].size()) { + Candidate query_unique(*item.in_edge_, j); + if (cs->count(&query_unique) == 0) { + Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); + cand.push_back(new_cand); + push_heap(cand.begin(), cand.end(), HeapCandCompare()); + bool is_new = cs->insert(new_cand).second; + assert(is_new); // insert into uniqueness set, sanity check + } + } + } + } + + //PushSucc following unique ancestor generation function + void PushSuccFast(const Candidate& item, const bool is_goal, CandidateHeap* pcand){ + CandidateHeap& cand = *pcand; + for (int i = 0; i < item.j_.size(); ++i) { + JVector j = item.j_; + ++j[i]; + if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { + Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); + cand.push_back(new_cand); + push_heap(cand.begin(), cand.end(), HeapCandCompare()); + } + if(item.j_[i]!=0){ + return; + } + } + } + + //PushSucc only if all ancest Cand are added + void PushSuccFast2(const Candidate& item, const bool is_goal, CandidateHeap* pcand, UniqueCandidateSet* ps){ + CandidateHeap& cand = *pcand; + for (int i = 0; i < item.j_.size(); ++i) { + JVector j = item.j_; + ++j[i]; + if (j[i] < D[item.in_edge_->tail_nodes_[i]].size()) { + Candidate query_unique(*item.in_edge_, j); + if (HasAllAncestors(&query_unique,ps)) { + Candidate* new_cand = new Candidate(*item.in_edge_, j, out, D, node_states_, smeta, models, is_goal); + cand.push_back(new_cand); + push_heap(cand.begin(), cand.end(), HeapCandCompare()); + } + } + } + } + + bool HasAllAncestors(const Candidate* item, UniqueCandidateSet* cs){ + for (int i = 0; i < item->j_.size(); ++i) { + JVector j = item->j_; + --j[i]; + if (j[i] >=0) { + Candidate query_unique(*item->in_edge_, j); + if (cs->count(&query_unique) == 0) { + return false; + } + } + } + return true; + } + + const ModelSet& models; + const SentenceMetadata& smeta; + const Hypergraph& in; + Hypergraph& out; + + vector D; // maps nodes in in-HG to the + // equivalent nodes (many due to state + // splits) in the out-HG. + FFStates node_states_; // for each node in the out-HG what is + // its q function value? + const int pop_limit_; + const int strategy_; //switch Cube Pruning strategy: 1 normal, 2 fast (alg 2), 3 fast_2 (alg 3). (see: Gesmundo A., Henderson J,. 
Faster Cube Pruning, IWSLT 2010) +}; + +struct NoPruningRescorer { + NoPruningRescorer(const ModelSet& m, const SentenceMetadata &sm, const Hypergraph& i, Hypergraph* o) : + models(m), + smeta(sm), + in(i), + out(*o), + nodemap(i.nodes_.size()) { + if (!SILENT) cerr << " Rescoring forest (full intersection)\n"; + node_states_.reserve(kRESERVE_NUM_NODES); + } + + typedef unordered_map > State2NodeIndex; + + void ExpandEdge(const Hypergraph::Edge& in_edge, bool is_goal, size_t head_node_hash, State2NodeIndex* state2node) { + const int arity = in_edge.Arity(); + Hypergraph::TailNodeVector ends(arity); + for (int i = 0; i < arity; ++i) + ends[i] = nodemap[in_edge.tail_nodes_[i]].size(); + + Hypergraph::TailNodeVector tail_iter(arity, 0); + bool done = false; + while (!done) { + Hypergraph::TailNodeVector tail(arity); + for (int i = 0; i < arity; ++i) + tail[i] = nodemap[in_edge.tail_nodes_[i]][tail_iter[i]]; + Hypergraph::Edge* new_edge = out.AddEdge(in_edge, tail); + FFState head_state; + if (is_goal) { + assert(tail.size() == 1); + const FFState& ant_state = node_states_[tail.front()]; + models.AddFinalFeatures(ant_state, new_edge,smeta); + } else { + prob_t edge_estimate; // this is a full intersection, so we disregard this + models.AddFeaturesToEdge(smeta, out, node_states_, new_edge, &head_state, &edge_estimate); + } + int& head_plus1 = (*state2node)[head_state]; + if (!head_plus1) { + HG::Node* new_node = out.AddNode(in_edge.rule_->GetLHS()); + new_node->node_hash = cdec::HashNode(head_node_hash, head_state); // ID is combination of existing state + residual state + head_plus1 = new_node->id_ + 1; + node_states_.push_back(head_state); + nodemap[in_edge.head_node_].push_back(head_plus1 - 1); + } + const int head_index = head_plus1 - 1; + out.ConnectEdgeToHeadNode(new_edge->id_, head_index); + + int ii = 0; + for (; ii < arity; ++ii) { + ++tail_iter[ii]; + if (tail_iter[ii] < ends[ii]) break; + tail_iter[ii] = 0; + } + done = (ii == arity); + } + } + + void ProcessOneNode(const int node_num, const bool is_goal) { + State2NodeIndex state2node; + const Hypergraph::Node& node = in.nodes_[node_num]; + for (int i = 0; i < node.in_edges_.size(); ++i) { + const Hypergraph::Edge& edge = in.edges_[node.in_edges_[i]]; + ExpandEdge(edge, is_goal, node.node_hash, &state2node); + } + } + + void Apply() { + int num_nodes = in.nodes_.size(); + int goal_id = num_nodes - 1; + int pregoal = goal_id - 1; + assert(in.nodes_[pregoal].out_edges_.size() == 1); + if (!SILENT) cerr << " "; + int has = 0; + for (int i = 0; i < in.nodes_.size(); ++i) { + if (!SILENT) { + int needs = (50 * i / in.nodes_.size()); + while (has < needs) { cerr << '.'; ++has; } + } + ProcessOneNode(i, i == goal_id); + } + if (!SILENT) cerr << endl; + } + + private: + const ModelSet& models; + const SentenceMetadata& smeta; + const Hypergraph& in; + Hypergraph& out; + + vector > nodemap; + FFStates node_states_; // for each node in the out-HG what is + // its q function value? +}; + +// each node in the graph has one of these, it keeps track of +void ApplyModelSet(const Hypergraph& in, + const SentenceMetadata& smeta, + const ModelSet& models, + const IntersectionConfiguration& config, + Hypergraph* out) { + //force exhaustive if there's no state req. 
for model + if (models.stateless() || config.algorithm == IntersectionConfiguration::FULL) { + NoPruningRescorer ma(models, smeta, in, out); // avoid overhead of best-first when no state + ma.Apply(); + } else if (config.algorithm == IntersectionConfiguration::CUBE || + config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING || + config.algorithm == + IntersectionConfiguration::FAST_CUBE_PRUNING_2) { + int pl = config.pop_limit; + const int max_pl_for_large=50; + if (pl > max_pl_for_large && in.nodes_.size() > 80000) { + pl = max_pl_for_large; + cerr << " Note: reducing pop_limit to " << pl << " for very large forest\n"; + } + if (config.algorithm == IntersectionConfiguration::CUBE) { + CubePruningRescorer ma(models, smeta, in, pl, out); + ma.Apply(); + } + else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING){ + CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP); + ma.Apply(); + } + else if (config.algorithm == IntersectionConfiguration::FAST_CUBE_PRUNING_2){ + CubePruningRescorer ma(models, smeta, in, pl, out, FAST_CP_2); + ma.Apply(); + } + + } else { + cerr << "Don't understand intersection algorithm " << config.algorithm << endl; + exit(1); + } + out->is_linear_chain_ = in.is_linear_chain_; // TODO remove when this is computed + // automatically +} diff --git a/decoder/apply_models.h b/decoder/apply_models.h new file mode 100644 index 000000000..bfb37df18 --- /dev/null +++ b/decoder/apply_models.h @@ -0,0 +1,43 @@ +#ifndef APPLY_MODELS_H_ +#define APPLY_MODELS_H_ + +#include + +class ModelSet; +class Hypergraph; +class SentenceMetadata; + +struct exhaustive_t {}; + +struct IntersectionConfiguration { +enum { + FULL, + CUBE, + FAST_CUBE_PRUNING, + FAST_CUBE_PRUNING_2, + N_ALGORITHMS +}; + + const int algorithm; // 0 = full intersection, 1 = cube pruning + const int pop_limit; // max number of pops off the heap at each node + IntersectionConfiguration(int alg, int k) : algorithm(alg), pop_limit(k) {} + IntersectionConfiguration(exhaustive_t /* t */) : algorithm(0), pop_limit() {} +}; + +inline std::ostream& operator<<(std::ostream& os, const IntersectionConfiguration& c) { + if (c.algorithm == 0) { os << "FULL"; } + else if (c.algorithm == 1) { os << "CUBE:k=" << c.pop_limit; } + else if (c.algorithm == 2) { os << "FAST_CUBE_PRUNING"; } + else if (c.algorithm == 3) { os << "FAST_CUBE_PRUNING_2"; } + else if (c.algorithm == 4) { os << "N_ALGORITHMS"; } + else os << "OTHER"; + return os; +} + +void ApplyModelSet(const Hypergraph& in, + const SentenceMetadata& smeta, + const ModelSet& models, + const IntersectionConfiguration& config, + Hypergraph* out); + +#endif diff --git a/decoder/bottom_up_parser-rs.cc b/decoder/bottom_up_parser-rs.cc new file mode 100644 index 000000000..863d7e2fd --- /dev/null +++ b/decoder/bottom_up_parser-rs.cc @@ -0,0 +1,340 @@ +#include "bottom_up_parser-rs.h" + +#include +#include + +#include "node_state_hash.h" +#include "nt_span.h" +#include "hg.h" +#include "array2d.h" +#include "tdict.h" +#include "verbose.h" + +using namespace std; + +static WordID kEPS = 0; + +struct RSActiveItem; +class RSChart { + public: + RSChart(const string& goal, + const vector& grammars, + const Lattice& input, + Hypergraph* forest); + ~RSChart(); + + void AddToChart(const RSActiveItem& x, int i, int j); + void ConsumeTerminal(const RSActiveItem& x, int i, int j, int k); + void ConsumeNonTerminal(const RSActiveItem& x, int i, int j, int k); + bool Parse(); + inline bool GoalFound() const { return goal_idx_ >= 0; } + inline int GetGoalIndex() 
const { return goal_idx_; } + + private: + void ApplyRules(const int i, + const int j, + const RuleBin* rules, + const Hypergraph::TailNodeVector& tail, + const SparseVector& lattice_feats); + + // returns true if a new node was added to the chart + // false otherwise + bool ApplyRule(const int i, + const int j, + const TRulePtr& r, + const Hypergraph::TailNodeVector& ant_nodes, + const SparseVector& lattice_feats); + + void ApplyUnaryRules(const int i, const int j, const WordID& cat, unsigned nodeidx); + void TopoSortUnaries(); + + const vector& grammars_; + const Lattice& input_; + Hypergraph* forest_; + Array2D> chart_; // chart_(i,j) is the list of nodes (represented + // by their index in forest_->nodes_) derived spanning i,j + typedef map Cat2NodeMap; + Array2D nodemap_; + const WordID goal_cat_; // category that is being searched for at [0,n] + TRulePtr goal_rule_; + int goal_idx_; // index of goal node, if found + const int lc_fid_; + vector unaries_; // topologically sorted list of unary rules from all grammars + + static WordID kGOAL; // [Goal] +}; + +WordID RSChart::kGOAL = 0; + +// "a type-2 is identified by a trie node, an array of back-pointers to antecedent cells, and a span" +struct RSActiveItem { + explicit RSActiveItem(const GrammarIter* g, int i) : + gptr_(g), ant_nodes_(), lattice_feats(), i_(i) {} + void ExtendTerminal(int symbol, const SparseVector& src_feats) { + lattice_feats += src_feats; + if (symbol != kEPS) + gptr_ = gptr_->Extend(symbol); + } + void ExtendNonTerminal(const Hypergraph* hg, int node_index) { + gptr_ = gptr_->Extend(hg->nodes_[node_index].cat_); + ant_nodes_.push_back(node_index); + } + // returns false if the extension has failed + explicit operator bool() const { + return gptr_; + } + const GrammarIter* gptr_; + Hypergraph::TailNodeVector ant_nodes_; + SparseVector lattice_feats; + short i_; +}; + +// some notes on the implementation +// "X" in Rico's Algorithm 2 roughly looks like it is just a pointer into a grammar +// trie, but it is actually a full "dotted item" since it needs to contain the information +// to build the hypergraph (i.e., it must remember the antecedent nodes and where they are, +// also any information about the path costs). 
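+//
+// The resulting control flow (using the functions defined below) is roughly:
+//   RSActiveItem item(grammar->GetRoot(), i);  // dot at the trie root, span starts at i
+//   ConsumeTerminal(item, i, i, k);            // advance the dot over a lattice arc
+//   ConsumeNonTerminal(item, i, i, k);         // or over a completed node in chart_(i,k)
+// AddToChart() then applies any rules completed at the item's trie node and keeps
+// extending the item rightwards until the end of the input is reached.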
+ +RSChart::RSChart(const string& goal, + const vector& grammars, + const Lattice& input, + Hypergraph* forest) : + grammars_(grammars), + input_(input), + forest_(forest), + chart_(input.size()+1, input.size()+1), + nodemap_(input.size()+1, input.size()+1), + goal_cat_(TD::Convert(goal) * -1), + goal_rule_(new TRule("[Goal] ||| [" + goal + "] ||| [1]")), + goal_idx_(-1), + lc_fid_(FD::Convert("LatticeCost")), + unaries_() { + for (unsigned i = 0; i < grammars_.size(); ++i) { + const vector& u = grammars_[i]->GetAllUnaryRules(); + for (unsigned j = 0; j < u.size(); ++j) + unaries_.push_back(u[j]); + } + TopoSortUnaries(); + if (!kGOAL) kGOAL = TD::Convert("Goal") * -1; + if (!SILENT) cerr << " Goal category: [" << goal << ']' << endl; +} + +static bool TopoSortVisit(int node, vector& u, const map >& g, map& mark) { + if (mark[node] == 1) { + cerr << "[ERROR] Unary rule cycle detected involving [" << TD::Convert(-node) << "]\n"; + return false; // cycle detected + } else if (mark[node] == 2) { + return true; // already been + } + mark[node] = 1; + const map >::const_iterator nit = g.find(node); + if (nit != g.end()) { + const vector& edges = nit->second; + vector okay(edges.size(), true); + for (unsigned i = 0; i < edges.size(); ++i) { + okay[i] = TopoSortVisit(edges[i]->lhs_, u, g, mark); + if (!okay[i]) { + cerr << "[ERROR] Unary rule cycle detected, removing: " << edges[i]->AsString() << endl; + } + } + for (unsigned i = 0; i < edges.size(); ++i) { + if (okay[i]) u.push_back(edges[i]); + //if (okay[i]) cerr << "UNARY: " << edges[i]->AsString() << endl; + } + } + mark[node] = 2; + return true; +} + +void RSChart::TopoSortUnaries() { + vector u(unaries_.size()); u.clear(); + map > g; + map mark; + //cerr << "GOAL=" << TD::Convert(-goal_cat_) << endl; + mark[goal_cat_] = 2; + for (unsigned i = 0; i < unaries_.size(); ++i) { + //cerr << "Adding: " << unaries_[i]->AsString() << endl; + g[unaries_[i]->f()[0]].push_back(unaries_[i]); + } + //m[unaries_[i]->lhs_].push_back(unaries_[i]); + for (map >::iterator it = g.begin(); it != g.end(); ++it) { + //cerr << "PROC: " << TD::Convert(-it->first) << endl; + if (mark[it->first] > 0) { + //cerr << "Already saw [" << TD::Convert(-it->first) << "]\n"; + } else { + TopoSortVisit(it->first, u, g, mark); + } + } + unaries_.clear(); + for (int i = u.size() - 1; i >= 0; --i) + unaries_.push_back(u[i]); +} + +bool RSChart::ApplyRule(const int i, + const int j, + const TRulePtr& r, + const Hypergraph::TailNodeVector& ant_nodes, + const SparseVector& lattice_feats) { + Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes); + //cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl; + new_edge->prev_i_ = r->prev_i; + new_edge->prev_j_ = r->prev_j; + new_edge->i_ = i; + new_edge->j_ = j; + new_edge->feature_values_ = r->GetFeatureValues(); + new_edge->feature_values_ += lattice_feats; + Cat2NodeMap& c2n = nodemap_(i,j); + const bool is_goal = (r->GetLHS() == kGOAL); + const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS()); + Hypergraph::Node* node = NULL; + bool added_node = false; + if (ni == c2n.end()) { + //cerr << "(" << i << "," << j << ") => " << TD::Convert(-r->GetLHS()) << endl; + added_node = true; + node = forest_->AddNode(r->GetLHS()); + c2n[r->GetLHS()] = node->id_; + if (is_goal) { + assert(goal_idx_ == -1); + goal_idx_ = node->id_; + } else { + chart_(i,j).push_back(node->id_); + } + } else { + node = &forest_->nodes_[ni->second]; + } + forest_->ConnectEdgeToHeadNode(new_edge, node); + return added_node; +} + +void 
RSChart::ApplyRules(const int i, + const int j, + const RuleBin* rules, + const Hypergraph::TailNodeVector& tail, + const SparseVector& lattice_feats) { + const int n = rules->GetNumRules(); + //cerr << i << " " << j << ": NUM RULES: " << n << endl; + for (int k = 0; k < n; ++k) { + //cerr << i << " " << j << ": R=" << rules->GetIthRule(k)->AsString() << endl; + TRulePtr rule = rules->GetIthRule(k); + // apply rule, and if we create a new node, apply any necessary + // unary rules + if (ApplyRule(i, j, rule, tail, lattice_feats)) { + unsigned nodeidx = nodemap_(i,j)[rule->lhs_]; + ApplyUnaryRules(i, j, rule->lhs_, nodeidx); + } + } +} + +void RSChart::ApplyUnaryRules(const int i, const int j, const WordID& cat, unsigned nodeidx) { + for (unsigned ri = 0; ri < unaries_.size(); ++ri) { + //cerr << "At (" << i << "," << j << "): applying " << unaries_[ri]->AsString() << endl; + if (unaries_[ri]->f()[0] == cat) { + //cerr << " --MATCH\n"; + WordID new_lhs = unaries_[ri]->GetLHS(); + const Hypergraph::TailNodeVector ant(1, nodeidx); + if (ApplyRule(i, j, unaries_[ri], ant, SparseVector())) { + //cerr << "(" << i << "," << j << ") " << TD::Convert(-cat) << " ---> " << TD::Convert(-new_lhs) << endl; + unsigned nodeidx = nodemap_(i,j)[new_lhs]; + ApplyUnaryRules(i, j, new_lhs, nodeidx); + } + } + } +} + +void RSChart::AddToChart(const RSActiveItem& x, int i, int j) { + // deal with completed rules + const RuleBin* rb = x.gptr_->GetRules(); + if (rb) ApplyRules(i, j, rb, x.ant_nodes_, x.lattice_feats); + + //cerr << "Rules applied ... looking for extensions to consume for span (" << i << "," << j << ")\n"; + // continue looking for extensions of the rule to the right + for (unsigned k = j+1; k <= input_.size(); ++k) { + ConsumeTerminal(x, i, j, k); + ConsumeNonTerminal(x, i, j, k); + } +} + +void RSChart::ConsumeTerminal(const RSActiveItem& x, int i, int j, int k) { + //cerr << "ConsumeT(" << i << "," << j << "," << k << "):\n"; + + const unsigned check_edge_len = k - j; + // long-term TODO preindex this search so i->len->words is constant time rather than fan out + for (auto& in_edge : input_[j]) { + if (in_edge.dist2next == check_edge_len) { + //cerr << " Found word spanning (" << j << "," << k << ") in input, symbol=" << TD::Convert(in_edge.label) << endl; + RSActiveItem copy = x; + copy.ExtendTerminal(in_edge.label, in_edge.features); + if (copy) AddToChart(copy, i, k); + } + } +} + +void RSChart::ConsumeNonTerminal(const RSActiveItem& x, int i, int j, int k) { + //cerr << "ConsumeNT(" << i << "," << j << "," << k << "):\n"; + for (auto& nodeidx : chart_(j,k)) { + //cerr << " Found completed NT in (" << j << "," << k << ") of type " << TD::Convert(-forest_->nodes_[nodeidx].cat_) << endl; + RSActiveItem copy = x; + copy.ExtendNonTerminal(forest_, nodeidx); + if (copy) AddToChart(copy, i, k); + } +} + +bool RSChart::Parse() { + size_t in_size_2 = input_.size() * input_.size(); + forest_->nodes_.reserve(in_size_2 * 2); + size_t res = min(static_cast(2000000), static_cast(in_size_2 * 1000)); + forest_->edges_.reserve(res); + goal_idx_ = -1; + const int N = input_.size(); + for (int i = N - 1; i >= 0; --i) { + for (int j = i + 1; j <= N; ++j) { + for (unsigned gi = 0; gi < grammars_.size(); ++gi) { + RSActiveItem item(grammars_[gi]->GetRoot(), i); + ConsumeTerminal(item, i, i, j); + } + for (unsigned gi = 0; gi < grammars_.size(); ++gi) { + RSActiveItem item(grammars_[gi]->GetRoot(), i); + ConsumeNonTerminal(item, i, i, j); + } + } + } + + // look for goal + const vector& dh = chart_(0, 
input_.size()); + for (unsigned di = 0; di < dh.size(); ++di) { + const Hypergraph::Node& node = forest_->nodes_[dh[di]]; + if (node.cat_ == goal_cat_) { + Hypergraph::TailNodeVector ant(1, node.id_); + ApplyRule(0, input_.size(), goal_rule_, ant, SparseVector()); + } + } + if (!SILENT) cerr << endl; + + if (GoalFound()) + forest_->PruneUnreachable(forest_->nodes_.size() - 1); + return GoalFound(); +} + +RSChart::~RSChart() {} + +RSExhaustiveBottomUpParser::RSExhaustiveBottomUpParser( + const string& goal_sym, + const vector& grammars) : + goal_sym_(goal_sym), + grammars_(grammars) {} + +bool RSExhaustiveBottomUpParser::Parse(const Lattice& input, + Hypergraph* forest) const { + kEPS = TD::Convert("*EPS*"); + RSChart chart(goal_sym_, grammars_, input, forest); + const bool result = chart.Parse(); + + if (result) { + for (auto& node : forest->nodes_) { + Span prev; + const Span s = forest->NodeSpan(node.id_, &prev); + node.node_hash = cdec::HashNode(node.cat_, s.l, s.r, prev.l, prev.r); + } + } + return result; +} diff --git a/decoder/bottom_up_parser-rs.h b/decoder/bottom_up_parser-rs.h new file mode 100644 index 000000000..2e271e997 --- /dev/null +++ b/decoder/bottom_up_parser-rs.h @@ -0,0 +1,29 @@ +#ifndef RSBOTTOM_UP_PARSER_H_ +#define RSBOTTOM_UP_PARSER_H_ + +#include +#include + +#include "lattice.h" +#include "grammar.h" + +class Hypergraph; + +// implementation of Sennrich (2014) parser +// http://aclweb.org/anthology/W/W14/W14-4011.pdf +class RSExhaustiveBottomUpParser { + public: + RSExhaustiveBottomUpParser(const std::string& goal_sym, + const std::vector& grammars); + + // returns true if goal reached spanning the full input + // forest contains the full (i.e., unpruned) parse forest + bool Parse(const Lattice& input, + Hypergraph* forest) const; + + private: + const std::string goal_sym_; + const std::vector grammars_; +}; + +#endif diff --git a/decoder/bottom_up_parser.cc b/decoder/bottom_up_parser.cc new file mode 100644 index 000000000..7ce8e09d3 --- /dev/null +++ b/decoder/bottom_up_parser.cc @@ -0,0 +1,367 @@ +//TODO: when using many nonterminals, group passive edges for a span (treat all as a single X for the active items). 
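+
+//This file implements the exhaustive bottom-up parser: PassiveChart keeps the
+//completed nodes for each span (i,j), ActiveChart keeps dotted grammar items whose
+//dots are advanced over terminals and over completed spans, and unary rules are
+//topologically sorted once and applied after every cell is filled.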
+ +//TODO: figure out what cdyer was talking about when he said that having unary rules A->B and B->A, doesn't make cycles appear in result provided rules are sorted in some way (that they typically are) + +#include "bottom_up_parser.h" + +#include +#include + +#include "node_state_hash.h" +#include "nt_span.h" +#include "hg.h" +#include "array2d.h" +#include "tdict.h" +#include "verbose.h" + +using namespace std; + +static WordID kEPS = 0; + +class ActiveChart; +class PassiveChart { + public: + PassiveChart(const string& goal, + const vector& grammars, + const Lattice& input, + Hypergraph* forest); + ~PassiveChart(); + + inline const vector& operator()(int i, int j) const { return chart_(i,j); } + bool Parse(); + inline int size() const { return chart_.width(); } + inline bool GoalFound() const { return goal_idx_ >= 0; } + inline int GetGoalIndex() const { return goal_idx_; } + + private: + void ApplyRules(const int i, + const int j, + const RuleBin* rules, + const Hypergraph::TailNodeVector& tail, + const SparseVector& lattice_feats); + + void ApplyRule(const int i, + const int j, + const TRulePtr& r, + const Hypergraph::TailNodeVector& ant_nodes, + const SparseVector& lattice_feats); + + void ApplyUnaryRules(const int i, const int j); + void TopoSortUnaries(); + + const vector& grammars_; + const Lattice& input_; + Hypergraph* forest_; + Array2D > chart_; // chart_(i,j) is the list of nodes derived spanning i,j + typedef map Cat2NodeMap; + Array2D nodemap_; + vector act_chart_; + const WordID goal_cat_; // category that is being searched for at [0,n] + TRulePtr goal_rule_; + int goal_idx_; // index of goal node, if found + vector unaries_; // topologically sorted list of unary rules from all grammars + + static WordID kGOAL; // [Goal] +}; + +WordID PassiveChart::kGOAL = 0; + +class ActiveChart { + public: + ActiveChart(const Hypergraph* hg, const PassiveChart& psv_chart) : + hg_(hg), + act_chart_(psv_chart.size(), psv_chart.size()), psv_chart_(psv_chart) {} + + struct ActiveItem { + ActiveItem(const GrammarIter* g, const Hypergraph::TailNodeVector& a, const SparseVector& lfeats) : + gptr_(g), ant_nodes_(a), lattice_feats(lfeats) {} + explicit ActiveItem(const GrammarIter* g) : + gptr_(g), ant_nodes_(), lattice_feats() {} + + void ExtendTerminal(int symbol, const SparseVector& src_feats, vector* out_cell) const { + if (symbol == kEPS) { + out_cell->push_back(ActiveItem(gptr_, ant_nodes_, lattice_feats + src_feats)); + } else { + const GrammarIter* ni = gptr_->Extend(symbol); + if (ni) + out_cell->push_back(ActiveItem(ni, ant_nodes_, lattice_feats + src_feats)); + } + } + void ExtendNonTerminal(const Hypergraph* hg, int node_index, vector* out_cell) const { + int symbol = hg->nodes_[node_index].cat_; + const GrammarIter* ni = gptr_->Extend(symbol); + if (!ni) return; + Hypergraph::TailNodeVector na(ant_nodes_.size() + 1); + for (unsigned i = 0; i < ant_nodes_.size(); ++i) + na[i] = ant_nodes_[i]; + na[ant_nodes_.size()] = node_index; + out_cell->push_back(ActiveItem(ni, na, lattice_feats)); + } + + const GrammarIter* gptr_; + Hypergraph::TailNodeVector ant_nodes_; + SparseVector lattice_feats; + }; + + inline const vector& operator()(int i, int j) const { return act_chart_(i,j); } + void SeedActiveChart(const Grammar& g) { + int size = act_chart_.width(); + for (int i = 0; i < size; ++i) + if (g.HasRuleForSpan(i,i,0)) + act_chart_(i,i).push_back(ActiveItem(g.GetRoot())); + } + + void ExtendActiveItems(int i, int k, int j) { + //cerr << " LOOK(" << i << "," << k << ") for completed items 
in (" << k << "," << j << ")\n"; + vector& cell = act_chart_(i,j); + const vector& icell = act_chart_(i,k); + const vector& idxs = psv_chart_(k, j); + //if (!idxs.empty()) { cerr << "FOUND IN (" << k << "," << j << ")\n"; } + for (vector::const_iterator di = icell.begin(); di != icell.end(); ++di) { + for (vector::const_iterator ni = idxs.begin(); ni != idxs.end(); ++ni) { + di->ExtendNonTerminal(hg_, *ni, &cell); + } + } + } + + void AdvanceDotsForAllItemsInCell(int i, int j, const vector >& input) { + //cerr << "ADVANCE(" << i << "," << j << ")\n"; + for (int k=i+1; k < j; ++k) + ExtendActiveItems(i, k, j); + + const vector& out_arcs = input[j-1]; + for (vector::const_iterator ai = out_arcs.begin(); + ai != out_arcs.end(); ++ai) { + const WordID& f = ai->label; + const SparseVector& c = ai->features; + const int& len = ai->dist2next; + //cerr << "F: " << TD::Convert(f) << " dest=" << i << "," << (j+len-1) << endl; + const vector& ec = act_chart_(i, j-1); + //cerr << " SRC=" << i << "," << (j-1) << " [ec=" << ec.size() << "]" << endl; + //if (ec.size() > 0) { cerr << " LC=" << ec[0].lattice_feats << endl; } + for (vector::const_iterator di = ec.begin(); di != ec.end(); ++di) + di->ExtendTerminal(f, c, &act_chart_(i, j + len - 1)); + } + } + + private: + const Hypergraph* hg_; + Array2D > act_chart_; + const PassiveChart& psv_chart_; +}; + +PassiveChart::PassiveChart(const string& goal, + const vector& grammars, + const Lattice& input, + Hypergraph* forest) : + grammars_(grammars), + input_(input), + forest_(forest), + chart_(input.size()+1, input.size()+1), + nodemap_(input.size()+1, input.size()+1), + goal_cat_(TD::Convert(goal) * -1), + goal_rule_(new TRule("[Goal] ||| [" + goal + "] ||| [1]")), + goal_idx_(-1), + unaries_() { + act_chart_.resize(grammars_.size()); + for (unsigned i = 0; i < grammars_.size(); ++i) { + act_chart_[i] = new ActiveChart(forest, *this); + const vector& u = grammars_[i]->GetAllUnaryRules(); + for (unsigned j = 0; j < u.size(); ++j) + unaries_.push_back(u[j]); + } + TopoSortUnaries(); + if (!kGOAL) kGOAL = TD::Convert("Goal") * -1; + if (!SILENT) cerr << " Goal category: [" << goal << ']' << endl; +} + +static bool TopoSortVisit(int node, vector& u, const map >& g, map& mark) { + if (mark[node] == 1) { + cerr << "[ERROR] Unary rule cycle detected involving [" << TD::Convert(-node) << "]\n"; + return false; // cycle detected + } else if (mark[node] == 2) { + return true; // already been + } + mark[node] = 1; + const map >::const_iterator nit = g.find(node); + if (nit != g.end()) { + const vector& edges = nit->second; + vector okay(edges.size(), true); + for (unsigned i = 0; i < edges.size(); ++i) { + okay[i] = TopoSortVisit(edges[i]->lhs_, u, g, mark); + if (!okay[i]) { + cerr << "[ERROR] Unary rule cycle detected, removing: " << edges[i]->AsString() << endl; + } + } + for (unsigned i = 0; i < edges.size(); ++i) { + if (okay[i]) u.push_back(edges[i]); + //if (okay[i]) cerr << "UNARY: " << edges[i]->AsString() << endl; + } + } + mark[node] = 2; + return true; +} + +void PassiveChart::TopoSortUnaries() { + vector u(unaries_.size()); u.clear(); + map > g; + map mark; + //cerr << "GOAL=" << TD::Convert(-goal_cat_) << endl; + mark[goal_cat_] = 2; + for (unsigned i = 0; i < unaries_.size(); ++i) { + //cerr << "Adding: " << unaries_[i]->AsString() << endl; + g[unaries_[i]->f()[0]].push_back(unaries_[i]); + } + //m[unaries_[i]->lhs_].push_back(unaries_[i]); + for (map >::iterator it = g.begin(); it != g.end(); ++it) { + //cerr << "PROC: " << TD::Convert(-it->first) << 
endl; + if (mark[it->first] > 0) { + //cerr << "Already saw [" << TD::Convert(-it->first) << "]\n"; + } else { + TopoSortVisit(it->first, u, g, mark); + } + } + unaries_.clear(); + for (int i = u.size() - 1; i >= 0; --i) + unaries_.push_back(u[i]); +} + +void PassiveChart::ApplyRule(const int i, + const int j, + const TRulePtr& r, + const Hypergraph::TailNodeVector& ant_nodes, + const SparseVector& lattice_feats) { + Hypergraph::Edge* new_edge = forest_->AddEdge(r, ant_nodes); + // cerr << i << " " << j << ": APPLYING RULE: " << r->AsString() << endl; + new_edge->prev_i_ = r->prev_i; + new_edge->prev_j_ = r->prev_j; + new_edge->i_ = i; + new_edge->j_ = j; + new_edge->feature_values_ = r->GetFeatureValues(); + new_edge->feature_values_ += lattice_feats; + Cat2NodeMap& c2n = nodemap_(i,j); + const bool is_goal = (r->GetLHS() == kGOAL); + const Cat2NodeMap::iterator ni = c2n.find(r->GetLHS()); + Hypergraph::Node* node = NULL; + if (ni == c2n.end()) { + node = forest_->AddNode(r->GetLHS()); + c2n[r->GetLHS()] = node->id_; + if (is_goal) { + assert(goal_idx_ == -1); + goal_idx_ = node->id_; + } else { + chart_(i,j).push_back(node->id_); + } + } else { + node = &forest_->nodes_[ni->second]; + } + forest_->ConnectEdgeToHeadNode(new_edge, node); +} + +void PassiveChart::ApplyRules(const int i, + const int j, + const RuleBin* rules, + const Hypergraph::TailNodeVector& tail, + const SparseVector& lattice_feats) { + const int n = rules->GetNumRules(); + //cerr << i << " " << j << ": NUM RULES: " << n << endl; + for (int k = 0; k < n; ++k) { + //cerr << i << " " << j << ": R=" << rules->GetIthRule(k)->AsString() << endl; + ApplyRule(i, j, rules->GetIthRule(k), tail, lattice_feats); + } +} + +void PassiveChart::ApplyUnaryRules(const int i, const int j) { + const vector& nodes = chart_(i,j); // reference is important! 
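+  // 'nodes' aliases chart_(i,j): any node added by ApplyRule below is appended while
+  // this loop runs (nodes.size() is re-read each iteration), so chains of unary rules
+  // are also applied to the newly created nodes.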
+ for (unsigned di = 0; di < nodes.size(); ++di) { + const WordID cat = forest_->nodes_[nodes[di]].cat_; + for (unsigned ri = 0; ri < unaries_.size(); ++ri) { + //cerr << "At (" << i << "," << j << "): applying " << unaries_[ri]->AsString() << endl; + if (unaries_[ri]->f()[0] == cat) { + //cerr << " --MATCH\n"; + const Hypergraph::TailNodeVector ant(1, nodes[di]); + ApplyRule(i, j, unaries_[ri], ant, SparseVector()); // may update nodes + } + } + } +} + +bool PassiveChart::Parse() { + size_t in_size_2 = input_.size() * input_.size(); + forest_->nodes_.reserve(in_size_2 * 2); + size_t res = min(static_cast(2000000), static_cast(in_size_2 * 1000)); + forest_->edges_.reserve(res); + goal_idx_ = -1; + for (unsigned gi = 0; gi < grammars_.size(); ++gi) + act_chart_[gi]->SeedActiveChart(*grammars_[gi]); + + if (!SILENT) cerr << " "; + for (unsigned l=1; lAdvanceDotsForAllItemsInCell(i, j, input_); + + const vector& cell = (*act_chart_[gi])(i,j); + for (vector::const_iterator ai = cell.begin(); + ai != cell.end(); ++ai) { + const RuleBin* rules = (ai->gptr_->GetRules()); + if (!rules) continue; + ApplyRules(i, j, rules, ai->ant_nodes_, ai->lattice_feats); + } + } + } + ApplyUnaryRules(i,j); + + for (unsigned gi = 0; gi < grammars_.size(); ++gi) { + const Grammar& g = *grammars_[gi]; + // deal with non-terminals that were just proved + if (g.HasRuleForSpan(i, j, input_.Distance(i,j))) + act_chart_[gi]->ExtendActiveItems(i, i, j); + } + } + const vector& dh = chart_(0, input_.size()); + for (unsigned di = 0; di < dh.size(); ++di) { + const Hypergraph::Node& node = forest_->nodes_[dh[di]]; + if (node.cat_ == goal_cat_) { + Hypergraph::TailNodeVector ant(1, node.id_); + ApplyRule(0, input_.size(), goal_rule_, ant, SparseVector()); + } + } + } + if (!SILENT) cerr << endl; + + if (GoalFound()) + forest_->PruneUnreachable(forest_->nodes_.size() - 1); + return GoalFound(); +} + +PassiveChart::~PassiveChart() { + for (unsigned i = 0; i < act_chart_.size(); ++i) + delete act_chart_[i]; +} + +ExhaustiveBottomUpParser::ExhaustiveBottomUpParser( + const string& goal_sym, + const vector& grammars) : + goal_sym_(goal_sym), + grammars_(grammars) {} + +bool ExhaustiveBottomUpParser::Parse(const Lattice& input, + Hypergraph* forest) const { + kEPS = TD::Convert("*EPS*"); + PassiveChart chart(goal_sym_, grammars_, input, forest); + const bool result = chart.Parse(); + + if (result) { + for (auto& node : forest->nodes_) { + Span prev; + const Span s = forest->NodeSpan(node.id_, &prev); + node.node_hash = cdec::HashNode(node.cat_, s.l, s.r, prev.l, prev.r); + } + } + return result; +} diff --git a/decoder/bottom_up_parser.h b/decoder/bottom_up_parser.h new file mode 100644 index 000000000..628bb96d1 --- /dev/null +++ b/decoder/bottom_up_parser.h @@ -0,0 +1,27 @@ +#ifndef BOTTOM_UP_PARSER_H_ +#define BOTTOM_UP_PARSER_H_ + +#include +#include + +#include "lattice.h" +#include "grammar.h" + +class Hypergraph; + +class ExhaustiveBottomUpParser { + public: + ExhaustiveBottomUpParser(const std::string& goal_sym, + const std::vector& grammars); + + // returns true if goal reached spanning the full input + // forest contains the full (i.e., unpruned) parse forest + bool Parse(const Lattice& input, + Hypergraph* forest) const; + + private: + const std::string goal_sym_; + const std::vector grammars_; +}; + +#endif diff --git a/decoder/cdec.cc b/decoder/cdec.cc new file mode 100644 index 000000000..cc3fcff11 --- /dev/null +++ b/decoder/cdec.cc @@ -0,0 +1,47 @@ +#include + +#include "filelib.h" +#include "decoder.h" +#include 
"ff_register.h" +#include "verbose.h" +#include "timing_stats.h" +#include "util/usage.hh" + +using namespace std; + +int main(int argc, char** argv) { + register_feature_functions(); + Decoder decoder(argc, argv); + + const string input = decoder.GetConf()["input"].as(); + const bool show_feature_dictionary = decoder.GetConf().count("show_feature_dictionary"); + if (!SILENT) cerr << "Reading input from " << ((input == "-") ? "STDIN" : input.c_str()) << endl; + ReadFile in_read(input); + istream *in = in_read.stream(); + assert(*in); + + string buf; +#ifdef CP_TIME + clock_t time_cp(0);//, end_cp; +#endif + while(*in) { + getline(*in, buf); + if (buf.empty()) continue; + decoder.Decode(buf); + } + Timer::Summarize(); +#ifdef CP_TIME + cerr << "Time required for Cube Pruning execution: " + << CpTime::Get() + << " seconds." << "\n\n"; +#endif + if (show_feature_dictionary) { + int num = FD::NumFeats(); + for (int i = 1; i < num; ++i) { + cout << FD::Convert(i) << endl; + } + } + util::PrintUsage(std::cerr); + return 0; +} + diff --git a/decoder/cdec_ff.cc b/decoder/cdec_ff.cc new file mode 100644 index 000000000..973a643a3 --- /dev/null +++ b/decoder/cdec_ff.cc @@ -0,0 +1,85 @@ +#include + +#include "ff.h" +#include "ff_basic.h" +#include "ff_context.h" +#include "ff_const_reorder.h" +#include "ff_spans.h" +#include "ff_lm.h" +#include "ff_klm.h" +#include "ff_ngrams.h" +#include "ff_csplit.h" +#include "ff_wordalign.h" +#include "ff_tagger.h" +#include "ff_factory.h" +#include "ff_rules.h" +#include "ff_ruleshape.h" +#include "ff_bleu.h" +#include "ff_soft_syn.h" +#include "ff_soft_syntax.h" +#include "ff_soft_syntax_mindist.h" +#include "ff_source_path.h" +#include "ff_parse_match.h" +#include "ff_source_syntax.h" +#include "ff_source_syntax2.h" +#include "ff_register.h" +#include "ff_charset.h" +#include "ff_wordset.h" +#include "ff_external.h" +#include "ff_lexical.h" + + +void register_feature_functions() { + static bool registered = false; + if (registered) { + assert(!"register_feature_functions() called twice!"); + } + registered = true; + + RegisterFF(); + + RegisterFF(); + RegisterFF(); + RegisterFF(); + RegisterFF(); + RegisterFF(); + + //TODO: use for all features the new Register which requires static FF::usage(false,false) give name + ff_registry.Register("SpanFeatures", new FFFactory()); + ff_registry.Register("NgramFeatures", new FFFactory()); + ff_registry.Register("RuleContextFeatures", new FFFactory()); + ff_registry.Register("RuleIdentityFeatures", new FFFactory()); + ff_registry.Register("ParseMatchFeatures", new FFFactory); + ff_registry.Register("SoftSyntaxFeatures", new FFFactory); + ff_registry.Register("SoftSyntaxFeaturesMindist", new FFFactory); + ff_registry.Register("SourceSyntaxFeatures", new FFFactory); + ff_registry.Register("SourceSpanSizeFeatures", new FFFactory); + ff_registry.Register("SourceSyntaxFeatures2", new FFFactory); + ff_registry.Register("CMR2008ReorderingFeatures", new FFFactory()); + ff_registry.Register("RuleSourceBigramFeatures", new FFFactory()); + ff_registry.Register("RuleTargetBigramFeatures", new FFFactory()); + ff_registry.Register("KLanguageModel", new KLanguageModelFactory()); + ff_registry.Register("NonLatinCount", new FFFactory); + ff_registry.Register("RuleShape", new FFFactory); + ff_registry.Register("RuleShape2", new FFFactory); + ff_registry.Register("RelativeSentencePosition", new FFFactory); + ff_registry.Register("LexNullJump", new FFFactory); + ff_registry.Register("NewJump", new FFFactory); + 
ff_registry.Register("SourceBigram", new FFFactory); + ff_registry.Register("Fertility", new FFFactory); + ff_registry.Register("BlunsomSynchronousParseHack", new FFFactory); + ff_registry.Register("CSplit_BasicFeatures", new FFFactory); + ff_registry.Register("CSplit_ReverseCharLM", new FFFactory); + ff_registry.Register("Tagger_BigramIndicator", new FFFactory); + ff_registry.Register("LexicalPairIndicator", new FFFactory); + ff_registry.Register("OutputIndicator", new FFFactory); + ff_registry.Register("IdentityCycleDetector", new FFFactory); + ff_registry.Register("InputIndicator", new FFFactory); + ff_registry.Register("LexicalTranslationTrigger", new FFFactory); + ff_registry.Register("WordPairFeatures", new FFFactory); + ff_registry.Register("SourcePathFeatures", new FFFactory); + ff_registry.Register("WordSet", new FFFactory); + ff_registry.Register("ConstReorderFeature", new FFFactory); + ff_registry.Register("External", new FFFactory); + ff_registry.Register("SoftSynFeature", new SoftSynFeatureFactory()); +} diff --git a/decoder/csplit.cc b/decoder/csplit.cc new file mode 100644 index 000000000..7a6ed1020 --- /dev/null +++ b/decoder/csplit.cc @@ -0,0 +1,176 @@ +#include "csplit.h" + +#include + +#include "filelib.h" +#include "stringlib.h" +#include "hg.h" +#include "tdict.h" +#include "grammar.h" +#include "sentence_metadata.h" + +using namespace std; + +struct CompoundSplitImpl { + CompoundSplitImpl(const boost::program_options::variables_map& conf) : + fugen_elements_(true), + min_size_(3), + kXCAT(TD::Convert("X")*-1), + kWORDBREAK_RULE(new TRule("[X] ||| # ||| #")), + kTEMPLATE_RULE(new TRule("[X] ||| [X,1] ? ||| [1] ?")), + kGOAL_RULE(new TRule("[Goal] ||| [X,1] ||| [1]")), + kFUGEN_S(FD::Convert("FugS")), + kFUGEN_N(FD::Convert("FugN")) { + // TODO: use conf to turn fugenelements on and off + } + + void PasteTogetherStrings(const vector& chars, + const int i, + const int j, + string* yield) { + int size = 0; + for (int k=i; kresize(size); + int cur = 0; + for (int k=i; k& chars, + Hypergraph* forest) { + vector nodes(chars.size()+1, -1); + nodes[0] = forest->AddNode(kXCAT)->id_; // source + const int left_rule = forest->AddEdge(kWORDBREAK_RULE, Hypergraph::TailNodeVector())->id_; + forest->ConnectEdgeToHeadNode(left_rule, nodes[0]); + + const int max_split_ = max(static_cast(chars.size()) - min_size_ + 1, 1); + // cerr << "max: " << max_split_ << " " << " min: " << min_size_ << endl; + for (int i = min_size_; i < max_split_; ++i) + nodes[i] = forest->AddNode(kXCAT)->id_; + assert(nodes.back() == -1); + nodes.back() = forest->AddNode(kXCAT)->id_; // sink + + for (int i = 0; i < max_split_; ++i) { + if (nodes[i] < 0) continue; + const int start = min(i + min_size_, static_cast(chars.size())); + for (int j = start; j <= chars.size(); ++j) { + if (nodes[j] < 0) continue; + string yield; + PasteTogetherStrings(chars, i, j, &yield); + // cerr << "[" << i << "," << j << "] " << yield << endl; + TRulePtr rule = TRulePtr(new TRule(*kTEMPLATE_RULE)); + rule->e_[1] = rule->f_[1] = TD::Convert(yield); + // cerr << rule->AsString() << endl; + int edge = forest->AddEdge( + rule, + Hypergraph::TailNodeVector(1, nodes[i]))->id_; + forest->ConnectEdgeToHeadNode(edge, nodes[j]); + forest->edges_[edge].i_ = i; + forest->edges_[edge].j_ = j; + + // handle "fugenelemente" here + // don't delete "fugenelemente" at the end of words + if (fugen_elements_ && j != chars.size()) { + const int len = yield.size(); + string alt; + int fid = 0; + if (len > (min_size_ + 2) && yield[len-1] == 's' && 
yield[len-2] == 'e') { + alt = yield.substr(0, len - 2); + fid = kFUGEN_S; + } else if (len > (min_size_ + 1) && yield[len-1] == 's') { + alt = yield.substr(0, len - 1); + fid = kFUGEN_S; + } else if (len > (min_size_ + 2) && yield[len-2] == 'e' && yield[len-1] == 'n') { + alt = yield.substr(0, len - 1); + fid = kFUGEN_N; + } + if (alt.size()) { + TRulePtr altrule = TRulePtr(new TRule(*rule)); + altrule->e_[1] = TD::Convert(alt); + // cerr << altrule->AsString() << endl; + int edge = forest->AddEdge( + altrule, + Hypergraph::TailNodeVector(1, nodes[i]))->id_; + forest->ConnectEdgeToHeadNode(edge, nodes[j]); + forest->edges_[edge].feature_values_.set_value(fid, 1.0); + forest->edges_[edge].i_ = i; + forest->edges_[edge].j_ = j; + } + } + } + } + + // add goal rule + Hypergraph::TailNodeVector tail(1, forest->nodes_.size() - 1); + Hypergraph::Node* goal = forest->AddNode(TD::Convert("Goal")*-1); + Hypergraph::Edge* hg_edge = forest->AddEdge(kGOAL_RULE, tail); + forest->ConnectEdgeToHeadNode(hg_edge, goal); + } + private: + const bool fugen_elements_; + const int min_size_; + const WordID kXCAT; + const TRulePtr kWORDBREAK_RULE; + const TRulePtr kTEMPLATE_RULE; + const TRulePtr kGOAL_RULE; + const int kFUGEN_S; + const int kFUGEN_N; +}; + +CompoundSplit::CompoundSplit(const boost::program_options::variables_map& conf) : + pimpl_(new CompoundSplitImpl(conf)) {} + +static void SplitUTF8String(const string& in, vector* out) { + out->resize(in.size()); + int i = 0; + int c = 0; + while (i < in.size()) { + const int len = UTF8Len(in[i]); + assert(len); + (*out)[c] = in.substr(i, len); + ++c; + i += len; + } + out->resize(c); +} + +bool CompoundSplit::TranslateImpl(const string& input, + SentenceMetadata* smeta, + const vector& weights, + Hypergraph* forest) { + if (input.find(" ") != string::npos) { + cerr << " BAD INPUT: " << input << "\n CompoundSplit expects single words\n"; + abort(); + } + vector in; + SplitUTF8String(input, &in); + smeta->SetSourceLength(in.size()); // TODO do utf8 or somethign + for (int i = 0; i < in.size(); ++i) + smeta->src_lattice_.push_back(vector(1, LatticeArc(TD::Convert(in[i]), SparseVector(), 1))); + smeta->ComputeInputLatticeType(); + pimpl_->BuildTrellis(in, forest); + forest->Reweight(weights); + return true; +} + +int CompoundSplit::GetFullWordEdgeIndex(const Hypergraph& forest) { + assert(forest.nodes_.size() > 0); + const vector out_edges = forest.nodes_[0].out_edges_; + int max_edge = -1; + int max_j = -1; + for (int i = 0; i < out_edges.size(); ++i) { + const int j = forest.edges_[out_edges[i]].j_; + if (j > max_j) { + max_j = j; + max_edge = out_edges[i]; + } + } + assert(max_edge >= 0); + assert(max_edge < forest.edges_.size()); + return max_edge; +} + diff --git a/decoder/csplit.h b/decoder/csplit.h new file mode 100644 index 000000000..83d457b8f --- /dev/null +++ b/decoder/csplit.h @@ -0,0 +1,30 @@ +#ifndef CSPLIT_H_ +#define CSPLIT_H_ + +#include "translator.h" +#include "lattice.h" + +// this "translator" takes single words (with NO SPACES) and segments +// them using the approach described in: +// +// C. Dyer. (2009) Using a maximum entropy model to build segmentation +// lattices for MT. In Proceedings of NAACL HLT 2009. +// note, an extra word space marker # is inserted at the left edge of +// the forest! 
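+// For example, a German compound such as "tonbandaufnahme" yields edges for
+// segmentations like "# tonband aufnahme", and the FugS/FugN features fire when a
+// linking element (Fugenelement, e.g. a trailing -s- or -en-) is stripped from a
+// segment.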
+struct CompoundSplitImpl; +struct CompoundSplit : public Translator { + CompoundSplit(const boost::program_options::variables_map& conf); + bool TranslateImpl(const std::string& input, + SentenceMetadata* smeta, + const std::vector& weights, + Hypergraph* forest); + + // given a forest generated by CompoundSplit::Translate, + // find the edge representing the unsegmented form + static int GetFullWordEdgeIndex(const Hypergraph& forest); + + private: + boost::shared_ptr pimpl_; +}; + +#endif diff --git a/decoder/decoder.cc b/decoder/decoder.cc new file mode 100644 index 000000000..1e6c31943 --- /dev/null +++ b/decoder/decoder.cc @@ -0,0 +1,1128 @@ +#include "decoder.h" + +#ifndef HAVE_OLD_CPP +# include +#else +# include +namespace std { using std::tr1::unordered_map; } +#endif +#include +#include +#include +#include + +#include "stringlib.h" +#include "weights.h" +#include "filelib.h" +#include "fdict.h" +#include "timing_stats.h" +#include "verbose.h" +#include "b64featvector.h" + +#include "translator.h" +#include "phrasebased_translator.h" +#include "tagger.h" +#include "lextrans.h" +#include "lexalign.h" +#include "csplit.h" + +#include "lattice.h" +#include "hg.h" +#include "sentence_metadata.h" +#include "hg_intersect.h" +#include "hg_union.h" + +#include "oracle_bleu.h" +#include "apply_models.h" +#include "ff.h" +#include "ffset.h" +#include "ff_factory.h" +#include "viterbi.h" +#include "kbest.h" +#include "inside_outside.h" +#include "exp_semiring.h" +#include "sentence_metadata.h" +#include "sampler.h" + +#include "forest_writer.h" // TODO this section should probably be handled by an Observer +#include "incremental.h" +#include "hg_io.h" +#include "aligner.h" + +#ifdef CP_TIME + clock_t CpTime::time_; + void CpTime::Add(clock_t x){time_+=x;} + void CpTime::Sub(clock_t x){time_-=x;} + double CpTime::Get(){return (double)(time_)/CLOCKS_PER_SEC;} +#endif + +static const double kMINUS_EPSILON = -1e-6; // don't be too strict + +using namespace std; +namespace po = boost::program_options; + +static bool verbose_feature_functions=true; + +namespace Hack { void MaxTrans(const Hypergraph& in, int beam_size); } +namespace NgramCache { void Clear(); } + +DecoderObserver::~DecoderObserver() {} +void DecoderObserver::NotifyDecodingStart(const SentenceMetadata&) {} +void DecoderObserver::NotifySourceParseFailure(const SentenceMetadata&) {} +void DecoderObserver::NotifyTranslationForest(const SentenceMetadata&, Hypergraph*) {} +void DecoderObserver::NotifyAlignmentFailure(const SentenceMetadata&) {} +void DecoderObserver::NotifyAlignmentForest(const SentenceMetadata&, Hypergraph*) {} +void DecoderObserver::NotifyDecodingComplete(const SentenceMetadata&) {} + +enum SummaryFeature { + kNODE_RISK = 1, + kEDGE_RISK, + kEDGE_PROB +}; + + +struct ELengthWeightFunction { + double operator()(const Hypergraph::Edge& e) const { + return e.rule_->ELength() - e.rule_->Arity(); + } +}; +inline void ShowBanner() { + cerr << "cdec (c) 2009--2014 by Chris Dyer" << endl; +} + +inline string str(char const* name,po::variables_map const& conf) { + return conf[name].as(); +} + + +// print just the --long_opt names suitable for bash compgen +inline void print_options(std::ostream &out,po::options_description const& opts) { + typedef std::vector< boost::shared_ptr > Ds; + Ds const& ds=opts.options(); + out << '"'; + for (unsigned i=0;ilong_name(); + } + out << '"'; +} + +template +inline bool store_conf(po::variables_map const& conf,std::string const& name,V *v) { + if (conf.count(name)) { + *v=conf[name].as(); + 
return true; + } + return false; +} + +inline boost::shared_ptr make_ff(string const& ffp,bool verbose_feature_functions,char const* pre="") { + string ff, param; + SplitCommandAndParam(ffp, &ff, ¶m); + if (verbose_feature_functions && !SILENT) + cerr << pre << "feature: " << ff; + if (!SILENT) { + if (param.size() > 0) cerr << " (with config parameters '" << param << "')\n"; + else cerr << " (no config parameters)\n"; + } + boost::shared_ptr pf = ff_registry.Create(ff, param); + if (!pf) exit(1); + int nbyte=pf->StateSize(); + if (verbose_feature_functions && !SILENT) + cerr<<"State is "< models; + boost::shared_ptr inter_conf; + vector ffs; + boost::shared_ptr > weight_vector; + int fid_summary; // 0 == no summary feature + double density_prune; // 0 == don't density prune + double beam_prune; // 0 == don't beam prune +}; + +ostream& operator<<(ostream& os, const RescoringPass& rp) { + os << "[num_fn=" << rp.ffs.size(); + if (rp.inter_conf) { os << " int_alg=" << *rp.inter_conf; } + //if (rp.weight_vector.size() > 0) os << " new_weights"; + if (rp.fid_summary) os << " summary_feature=" << FD::Convert(rp.fid_summary); + if (rp.density_prune) os << " density_prune=" << rp.density_prune; + if (rp.beam_prune) os << " beam_prune=" << rp.beam_prune; + os << ']'; + return os; +} + +struct DecoderImpl { + DecoderImpl(po::variables_map& conf, int argc, char** argv, istream* cfg); + ~DecoderImpl(); + bool Decode(const string& input, DecoderObserver*); + vector& CurrentWeightVector() { + return (rescoring_passes.empty() ? *init_weights : *rescoring_passes.back().weight_vector); + } + void SetId(int next_sent_id) { sent_id = next_sent_id - 1; } + + void forest_stats(Hypergraph &forest,string name,bool show_tree,bool show_deriv=false, bool extract_rules=false, boost::shared_ptr extract_file = boost::make_shared()) { + cerr << viterbi_stats(forest,name,true,show_tree,show_deriv,extract_rules, extract_file); + cerr << endl; + } + + bool beam_param(po::variables_map const& conf,string const& name,double *val,bool scale_srclen=false,double srclen=1) { + if (conf.count(name)) { + *val=conf[name].as()*(scale_srclen?srclen:1); + return true; + } + return false; + } + + void maybe_prune(Hypergraph &forest,po::variables_map const& conf,string nbeam,string ndensity,string forestname,double srclen) { + double beam_prune=0,density_prune=0; + bool use_beam_prune=beam_param(conf,nbeam,&beam_prune,conf.count("scale_prune_srclen"),srclen); + bool use_density_prune=beam_param(conf,ndensity,&density_prune); + if (use_beam_prune || use_density_prune) { + double presize=forest.edges_.size(); + vector preserve_mask,*pm=0; + if (conf.count("csplit_preserve_full_word")) { + preserve_mask.resize(forest.edges_.size()); + preserve_mask[CompoundSplit::GetFullWordEdgeIndex(forest)] = true; + pm=&preserve_mask; + } + forest.PruneInsideOutside(beam_prune,density_prune,pm,false,1); + if (!forestname.empty()) forestname=" "+forestname; + if (!SILENT) { + forest_stats(forest," Pruned "+forestname+" forest",false,false); + cerr << " Pruned "< >& ss, int n, vector* out) { + const SampleSet& s = ss[n]; + int i = rng->SelectSample(s); + const Hypergraph::Edge& edge = hg.edges_[hg.nodes_[n].in_edges_[i]]; + vector > ants(edge.tail_nodes_.size()); + for (int j = 0; j < ants.size(); ++j) + SampleRecurse(hg, ss, edge.tail_nodes_[j], &ants[j]); + + vector*> pants(ants.size()); + for (int j = 0; j < ants.size(); ++j) pants[j] = &ants[j]; + edge.rule_->ESubstitute(pants, out); + } + + struct SampleSort { + bool operator()(const pair& a, const 
pair& b) const { + return a.first > b.first; + } + }; + + // TODO this should be handled by an Observer + void MaxTranslationSample(Hypergraph* hg, const int samples, const int k) { + unordered_map > m; + hg->PushWeightsToGoal(); + const int num_nodes = hg->nodes_.size(); + vector > ss(num_nodes); + for (int i = 0; i < num_nodes; ++i) { + SampleSet& s = ss[i]; + const vector& in_edges = hg->nodes_[i].in_edges_; + for (int j = 0; j < in_edges.size(); ++j) { + s.add(hg->edges_[in_edges[j]].edge_prob_); + } + } + for (int i = 0; i < samples; ++i) { + vector yield; + SampleRecurse(*hg, ss, hg->nodes_.size() - 1, &yield); + const string trans = TD::GetString(yield); + ++m[trans]; + } + vector > dist; + for (unordered_map >::iterator i = m.begin(); + i != m.end(); ++i) { + dist.push_back(make_pair(i->second, i->first)); + } + sort(dist.begin(), dist.end(), SampleSort()); + if (k) { + for (int i = 0; i < k; ++i) + cout << dist[i].first << " ||| " << dist[i].second << endl; + } else { + cout << dist[0].second << endl; + } + } + + void ParseTranslatorInputLattice(const string& line, string* input, Lattice* ref) { + string sref; + ParseTranslatorInput(line, input, &sref); + if (sref.size() > 0) { + assert(ref); + LatticeTools::ConvertTextOrPLF(sref, ref); + } + } + + // used to construct the suffix string to get the name of arguments for multiple passes + // e.g., the "2" in --weights2 + static string StringSuffixForRescoringPass(int pass) { + if (pass == 0) return ""; + string ps = "1"; + assert(pass < 9); + ps[0] += pass; + return ps; + } + + vector rescoring_passes; + + po::variables_map& conf; + OracleBleu oracle; + string formalism; + boost::shared_ptr translator; + boost::shared_ptr > init_weights; // weights used with initial parse + vector > pffs; + boost::shared_ptr > rng; + int sample_max_trans; + bool aligner_mode; + bool graphviz; + bool joshua_viz; + bool encode_b64; + bool kbest; + bool unique_kbest; + bool get_oracle_forest; + boost::shared_ptr extract_file; + int combine_size; + int sent_id; + SparseVector acc_vec; // accumulate gradient + double acc_obj; // accumulate objective + int g_count; // number of gradient pieces computed + bool csplit_output_plf; + bool write_gradient; // TODO Observer + bool feature_expectations; // TODO Observer + bool output_training_vector; // TODO Observer + bool remove_intersected_rule_annotations; + bool mr_mira_compat; // Mr.MIRA compatibility mode. 
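+  // lazy/incremental search state, used when --incremental_search is given (see incremental.h)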
+ boost::scoped_ptr incremental; + + + static void ConvertSV(const SparseVector& src, SparseVector* trg) { + for (SparseVector::const_iterator it = src.begin(); it != src.end(); ++it) + trg->set_value(it->first, it->second.as_float()); + } +}; + +DecoderImpl::~DecoderImpl() { + if (output_training_vector && !acc_vec.empty()) { + if (encode_b64) { + cout << "0\t"; + SparseVector dav; ConvertSV(acc_vec, &dav); + B64::Encode(acc_obj, dav, &cout); + cout << endl << flush; + } else { + cout << "0\t**OBJ**=" << acc_obj << ';' << acc_vec << endl << flush; + } + } +} + +DecoderImpl::DecoderImpl(po::variables_map& conf, int argc, char** argv, istream* cfg) : conf(conf) { + if (cfg) { if (argc || argv) { cerr << "DecoderImpl() can only take a file or command line options, not both\n"; exit(1); } } + bool show_config; + bool show_weights; + vector cfg_files; + + po::options_description opts("Configuration options"); + opts.add_options() + ("formalism,f",po::value(),"Decoding formalism; values include SCFG, FST, PB, LexTrans (lexical translation model, also disc training), CSplit (compound splitting), Tagger (sequence labeling), LexAlign (alignment only, or EM training)") + ("input,i",po::value()->default_value("-"),"Source file") + ("grammar,g",po::value >()->composing(),"Either SCFG grammar file(s) or phrase tables file(s)") + ("per_sentence_grammar_file", po::value(), "Optional (and possibly not implemented) per sentence grammar file enables all per sentence grammars to be stored in a single large file and accessed by offset") + ("list_feature_functions,L","List available feature functions") +#ifdef HAVE_CMPH + ("cmph_perfect_feature_hash,h", po::value(), "Load perfect hash function for features") +#endif + + ("weights,w",po::value(),"Feature weights file (initial forest / pass 1)") + ("feature_function,F",po::value >()->composing(), "Pass 1 additional feature function(s) (-L for list)") + ("intersection_strategy,I",po::value()->default_value("cube_pruning"), "Pass 1 intersection strategy for incorporating finite-state features; values include Cube_pruning, Full, Fast_cube_pruning, Fast_cube_pruning_2") + ("cubepruning_pop_limit,K",po::value()->default_value(200), "Max number of pops from the candidate heap at each node") + ("summary_feature", po::value(), "Compute a 'summary feature' at the end of the pass (before any pruning) with name=arg and value=inside-outside/Z") + ("summary_feature_type", po::value()->default_value("node_risk"), "Summary feature types: node_risk, edge_risk, edge_prob") + ("density_prune", po::value(), "Pass 1 pruning: keep no more than this many times the number of edges used in the best derivation tree (>=1.0)") + ("beam_prune", po::value(), "Pass 1 pruning: Prune paths from scored forest, keep paths within exp(alpha>=0)") + + ("weights2",po::value(),"Optional pass 2") + ("feature_function2",po::value >()->composing(), "Optional pass 2") + ("intersection_strategy2",po::value()->default_value("cube_pruning"), "Optional pass 2") + ("cubepruning_pop_limit2",po::value()->default_value(200), "Optional pass 2") + ("summary_feature2", po::value(), "Optional pass 2") + ("density_prune2", po::value(), "Optional pass 2") + ("beam_prune2", po::value(), "Optional pass 2") + + ("weights3",po::value(),"Optional pass 3") + ("feature_function3",po::value >()->composing(), "Optional pass 3") + ("intersection_strategy3",po::value()->default_value("cube_pruning"), "Optional pass 3") + ("cubepruning_pop_limit3",po::value()->default_value(200), "Optional pass 3") + ("summary_feature3", 
po::value(), "Optional pass 3") + ("density_prune3", po::value(), "Optional pass 3") + ("beam_prune3", po::value(), "Optional pass 3") + + ("add_pass_through_rules,P","Add rules to translate OOV words as themselves") + ("add_extra_pass_through_features,Q", po::value()->default_value(0), "Add PassThrough{1..N} features, capped at N.") + ("k_best,k",po::value(),"Extract the k best derivations") + ("unique_k_best,r", "Unique k-best translation list") + ("aligner,a", "Run as a word/phrase aligner (src & ref required)") + ("aligner_use_viterbi", "If run in alignment mode, compute the Viterbi (rather than MAP) alignment") + ("goal",po::value()->default_value("S"),"Goal symbol (SCFG & FST)") + ("freeze_feature_set,Z", "Freeze feature set after reading feature weights file") + ("warn_0_weight","Warn about any feature id that has a 0 weight (this is perfectly safe if you intend 0 weight, though)") + ("scfg_extra_glue_grammar", po::value(), "Extra glue grammar file (Glue grammars apply when i=0 but have no other span restrictions)") + ("scfg_no_hiero_glue_grammar,n", "No Hiero glue grammar (nb. by default the SCFG decoder adds Hiero glue rules)") + ("scfg_default_nt,d",po::value()->default_value("X"),"Default non-terminal symbol in SCFG") + ("scfg_max_span_limit,S",po::value()->default_value(10),"Maximum non-terminal span limit (except \"glue\" grammar)") + ("quiet", "Disable verbose output") + ("show_config", po::bool_switch(&show_config), "show contents of loaded -c config files.") + ("show_weights", po::bool_switch(&show_weights), "show effective feature weights") + ("show_feature_dictionary", "After decoding the last input, write the contents of the feature dictionary") + ("show_joshua_visualization,J", "Produce output compatible with the Joshua visualization tools") + ("show_tree_structure", "Show the Viterbi derivation structure") + ("show_expected_length", "Show the expected translation length under the model") + ("show_partition,z", "Compute and show the partition (inside score)") + ("show_conditional_prob", "Output the conditional log prob to STDOUT instead of a translation") + ("show_cfg_search_space", "Show the search space as a CFG") + ("show_cfg_alignment_space", "Show the alignment hypergraph as a CFG") + ("show_target_graph", po::value(), "Directory to write the target hypergraphs to") + ("incremental_search", po::value(), "Run lazy search with this language model file") + ("coarse_to_fine_beam_prune", po::value(), "Prune paths from coarse parse forest before fine parse, keeping paths within exp(alpha>=0)") + ("ctf_beam_widen", po::value()->default_value(2.0), "Expand coarse pass beam by this factor if no fine parse is found") + ("ctf_num_widenings", po::value()->default_value(2), "Widen coarse beam this many times before backing off to full parse") + ("ctf_no_exhaustive", "Do not fall back to exhaustive parse if coarse-to-fine parsing fails") + ("scale_prune_srclen", "scale beams by the input length (in # of tokens; may not be what you want for lattices") + ("lextrans_dynasearch", "'DynaSearch' neighborhood instead of usual partition, as defined by Smith & Eisner (2005)") + ("lextrans_use_null", "Support source-side null words in lexical translation") + ("lextrans_align_only", "Only used in alignment mode. 
Limit target words generated by reference") + ("tagger_tagset,t", po::value(), "(Tagger) file containing tag set") + ("csplit_output_plf", "(Compound splitter) Output lattice in PLF format") + ("csplit_preserve_full_word", "(Compound splitter) Always include the unsegmented form in the output lattice") + ("extract_rules", po::value(), "Extract the rules used in translation (not de-duped!) to a file in this directory") + ("show_derivations", po::value(), "Directory to print the derivation structures to") + ("show_derivations_mask", po::value()->default_value(Hypergraph::SPAN|Hypergraph::RULE), "Bit-mask for what to print in derivation structures") + ("graphviz","Show (constrained) translation forest in GraphViz format") + ("max_translation_beam,x", po::value(), "Beam approximation to get max translation from the chart") + ("max_translation_sample,X", po::value(), "Sample the max translation from the chart") + ("pb_max_distortion,D", po::value()->default_value(4), "Phrase-based decoder: maximum distortion") + ("cll_gradient,G","Compute conditional log-likelihood gradient and write to STDOUT (src & ref required)") + ("get_oracle_forest,o", "Calculate rescored hypergraph using approximate BLEU scoring of rules") + ("feature_expectations","Write feature expectations for all features in chart (**OBJ** will be the partition)") + ("vector_format",po::value()->default_value("b64"), "Sparse vector serialization format for feature expectations or gradients, includes (text or b64)") + ("combine_size,C",po::value()->default_value(1), "When option -G is used, process this many sentence pairs before writing the gradient (1=emit after every sentence pair)") + ("forest_output,O",po::value(),"Directory to write forests to") + ("remove_intersected_rule_annotations", "After forced decoding is completed, remove nonterminal annotations (i.e., the source side spans)") + ("mr_mira_compat", "Mr.MIRA compatibility mode (applies weight delta if available; outputs number of lines before k-best)"); + + // ob.AddOptions(&opts); + po::options_description clo("Command line options"); + clo.add_options() + ("config,c", po::value >(&cfg_files), "Configuration file(s) - latest has priority") + ("help,?", "Print this help message and exit") + ("usage,u", po::value(), "Describe a feature function type") + ("compgen", "Print just option names suitable for bash command line completion builtin 'compgen'") + ; + + po::options_description dconfig_options, dcmdline_options; + dconfig_options.add(opts); + + dcmdline_options.add(dconfig_options).add(clo); + if (argc) { + po::store(parse_command_line(argc, argv, dcmdline_options), conf); + if (conf.count("compgen")) { + print_options(cout,dcmdline_options); + cout << endl; + exit(0); + } + if (conf.count("quiet")) + SetSilent(true); + if (!SILENT) ShowBanner(); + } + if (conf.count("show_config")) // special handling needed because we only want to notify() once. 
+ show_config=true; + if (conf.count("config") && !cfg) { + typedef vector Cs; + Cs cs=conf["config"].as(); + for (int i=0;i() << " ...\n"; + FD::EnableHash(conf["cmph_perfect_feature_hash"].as()); + cerr << " " << FD::NumFeats() << " features in map\n"; + } + + // load initial feature weights (and possibly freeze feature set) + init_weights.reset(new vector); + if (conf.count("weights")) + Weights::InitFromFile(str("weights",conf), init_weights.get()); + + if (conf.count("extract_rules")) { + if (!DirectoryExists(conf["extract_rules"].as())) + MkDirP(conf["extract_rules"].as()); + } + + // determine the number of rescoring/pruning/weighting passes configured + const int MAX_PASSES = 3; + for (int pass = 0; pass < MAX_PASSES; ++pass) { + string ws = "weights" + StringSuffixForRescoringPass(pass); + string ff = "feature_function" + StringSuffixForRescoringPass(pass); + string sf = "summary_feature" + StringSuffixForRescoringPass(pass); + string bp = "beam_prune" + StringSuffixForRescoringPass(pass); + string dp = "density_prune" + StringSuffixForRescoringPass(pass); + bool first_pass_condition = ((pass == 0) && (conf.count(ff) || conf.count(bp) || conf.count(dp))); + bool nth_pass_condition = ((pass > 0) && (conf.count(ws) || conf.count(ff) || conf.count(bp) || conf.count(dp))); + if (first_pass_condition || nth_pass_condition) { + rescoring_passes.push_back(RescoringPass()); + RescoringPass& rp = rescoring_passes.back(); + // only configure new weights if pass > 0, otherwise we reuse the initial chart weights + if (nth_pass_condition && conf.count(ws)) { + rp.weight_vector.reset(new vector()); + Weights::InitFromFile(str(ws.c_str(), conf), rp.weight_vector.get()); + } + bool has_stateful = false; + if (conf.count(ff)) { + vector add_ffs; + store_conf(conf,ff,&add_ffs); + for (int i = 0; i < add_ffs.size(); ++i) { + pffs.push_back(make_ff(add_ffs[i],verbose_feature_functions)); + FeatureFunction const* p=pffs.back().get(); + rp.ffs.push_back(p); + if (p->IsStateful()) { has_stateful = true; } + } + } + if (conf.count(sf)) { + rp.fid_summary = FD::Convert(conf[sf].as()); + assert(rp.fid_summary > 0); + // TODO assert that weights for this pass have coef(fid_summary) == 0.0? + } + if (conf.count(bp)) { rp.beam_prune = conf[bp].as(); } + if (conf.count(dp)) { rp.density_prune = conf[dp].as(); } + int palg = (has_stateful ? 1 : 0); // if there are no stateful featueres, default to FULL + string isn = "intersection_strategy" + StringSuffixForRescoringPass(pass); + string spl = "cubepruning_pop_limit" + StringSuffixForRescoringPass(pass); + unsigned pop_limit = 200; + if (conf.count(spl)) { pop_limit = conf[spl].as(); } + if (LowercaseString(str(isn.c_str(),conf)) == "full") { + palg = 0; + } + if (LowercaseString(conf["intersection_strategy"].as()) == "fast_cube_pruning") { + palg = 2; + cerr << "Using Fast Cube Pruning intersection (see Algorithm 2 described in: Gesmundo A., Henderson J,. Faster Cube Pruning, IWSLT 2010).\n"; + } + if (LowercaseString(conf["intersection_strategy"].as()) == "fast_cube_pruning_2") { + palg = 3; + cerr << "Using Fast Cube Pruning 2 intersection (see Algorithm 3 described in: Gesmundo A., Henderson J,. 
Faster Cube Pruning, IWSLT 2010).\n"; + } + rp.inter_conf.reset(new IntersectionConfiguration(palg, pop_limit)); + } else { + break; // TODO alert user if there are any future configurations + } + } + + // set up weight vectors since later phases may reuse weights from earlier phases + boost::shared_ptr > prev_weights = init_weights; + for (int pass = 0; pass < rescoring_passes.size(); ++pass) { + RescoringPass& rp = rescoring_passes[pass]; + if (!rp.weight_vector) { + rp.weight_vector = prev_weights; + } else { + prev_weights = rp.weight_vector; + } + rp.models.reset(new ModelSet(*rp.weight_vector, rp.ffs)); + } + + // show configuration of rescoring passes + if (!SILENT) { + int num = rescoring_passes.size(); + cerr << "Configured " << num << " rescoring pass" << (num == 1 ? "" : "es") << endl; + for (int pass = 0; pass < num; ++pass) + cerr << " " << rescoring_passes[pass] << endl; + } + + bool warn0=conf.count("warn_0_weight"); + bool freeze=conf.count("freeze_feature_set"); + bool early_freeze=freeze && !warn0; + bool late_freeze=freeze && warn0; + if (early_freeze) { + cerr << "Freezing feature set" << endl; + FD::Freeze(); // this means we can't see the feature names of not-weighted features + } + + // set up translation back end + if (formalism == "scfg") + translator.reset(new SCFGTranslator(conf)); + else if (formalism == "t2s") + translator.reset(new Tree2StringTranslator(conf, false)); + else if (formalism == "t2t") + translator.reset(new Tree2StringTranslator(conf, true)); + else if (formalism == "fst") + translator.reset(new FSTTranslator(conf)); + else if (formalism == "pb") + translator.reset(new PhraseBasedTranslator(conf)); + else if (formalism == "csplit") + translator.reset(new CompoundSplit(conf)); + else if (formalism == "lextrans") + translator.reset(new LexicalTrans(conf)); + else if (formalism == "lexalign") + translator.reset(new LexicalAlign(conf)); + else if (formalism == "rescore") + translator.reset(new RescoreTranslator(conf)); + else if (formalism == "tagger") + translator.reset(new Tagger(conf)); + else + assert(!"error"); + + if (late_freeze) { + cerr << "Late freezing feature set (use --no_freeze_feature_set to prevent)." << endl; + FD::Freeze(); // this means we can't see the feature names of not-weighted features + } + + sample_max_trans = conf.count("max_translation_sample") ? 
+ conf["max_translation_sample"].as() : 0; + if (sample_max_trans) + rng.reset(new RandomNumberGenerator); + aligner_mode = conf.count("aligner"); + graphviz = conf.count("graphviz"); + joshua_viz = conf.count("show_joshua_visualization"); + encode_b64 = str("vector_format",conf) == "b64"; + kbest = conf.count("k_best"); + unique_kbest = conf.count("unique_k_best"); + get_oracle_forest = conf.count("get_oracle_forest"); + oracle.show_derivation=conf.count("show_derivations"); + oracle.show_derivation_mask=conf["show_derivations_mask"].as(); + remove_intersected_rule_annotations = conf.count("remove_intersected_rule_annotations"); + mr_mira_compat = conf.count("mr_mira_compat"); + + combine_size = conf["combine_size"].as(); + if (combine_size < 1) combine_size = 1; + sent_id = -1; + acc_obj = 0; // accumulate objective + g_count = 0; // number of gradient pieces computed + + if (conf.count("incremental_search")) { + incremental.reset(IncrementalBase::Load(conf["incremental_search"].as().c_str(), CurrentWeightVector())); + } +} + +Decoder::Decoder(istream* cfg) { pimpl_.reset(new DecoderImpl(conf,0,0,cfg)); } +Decoder::Decoder(int argc, char** argv) { pimpl_.reset(new DecoderImpl(conf,argc, argv, 0)); } +Decoder::~Decoder() {} +void Decoder::SetId(int next_sent_id) { pimpl_->SetId(next_sent_id); } +bool Decoder::Decode(const string& input, DecoderObserver* o) { + bool del = false; + if (!o) { o = new DecoderObserver; del = true; } + const bool res = pimpl_->Decode(input, o); + if (del) delete o; + return res; +} +vector& Decoder::CurrentWeightVector() { return pimpl_->CurrentWeightVector(); } +const vector& Decoder::CurrentWeightVector() const { return pimpl_->CurrentWeightVector(); } +void Decoder::AddSupplementalGrammar(GrammarPtr gp) { + static_cast(*pimpl_->translator).AddSupplementalGrammar(gp); +} +void Decoder::AddSupplementalGrammarFromString(const std::string& grammar_string) { + assert(pimpl_->translator->GetDecoderType() == "SCFG"); + static_cast(*pimpl_->translator).AddSupplementalGrammarFromString(grammar_string); +} + +static inline void ApplyWeightDelta(const string &delta_b64, vector *weights) { + SparseVector delta; + DecodeFeatureVector(delta_b64, &delta); + if (delta.empty()) return; + // Apply updates + for (SparseVector::iterator dit = delta.begin(); + dit != delta.end(); ++dit) { + int feat_id = dit->first; + union { weight_t weight; unsigned long long repr; } feat_delta; + feat_delta.weight = dit->second; + if (!SILENT) + cerr << "[decoder weight update] " << FD::Convert(feat_id) << " " << feat_delta.weight + << " = " << hex << feat_delta.repr << endl; + if (weights->size() <= feat_id) weights->resize(feat_id + 1); + (*weights)[feat_id] += feat_delta.weight; + } +} + +bool DecoderImpl::Decode(const string& input, DecoderObserver* o) { + string buf = input; + NgramCache::Clear(); // clear ngram cache for remote LM (if used) + Timer::Summarize(); + ++sent_id; + map sgml; + ProcessAndStripSGML(&buf, &sgml); + if (sgml.find("id") != sgml.end()) + sent_id = atoi(sgml["id"].c_str()); + + // Add delta from input to weights before decoding + if (mr_mira_compat) + ApplyWeightDelta(sgml["delta"], init_weights.get()); + + if (!SILENT) { + cerr << "\nINPUT: "; + if (buf.size() < 100) + cerr << buf << endl; + else { + size_t x = buf.rfind(" ", 100); + if (x == string::npos) x = 100; + cerr << buf.substr(0, x) << " ..." 
<< endl; + } + cerr << " id = " << sent_id << endl; + } + if (conf.count("extract_rules")) { + stringstream ss; + ss << sent_id << ".gz"; + extract_file.reset(new WriteFile(str("extract_rules",conf)+"/"+ss.str())); + } + string to_translate; + Lattice ref; + ParseTranslatorInputLattice(buf, &to_translate, &ref); + const unsigned srclen=NTokens(to_translate,' '); +//FIXME: should get the avg. or max source length of the input lattice (like Lattice::dist_(start,end)); but this is only used to scale beam parameters (optionally) anyway so fidelity isn't important. + const bool has_ref = ref.size() > 0; + SentenceMetadata smeta(sent_id, ref); + smeta.sgml_.swap(sgml); + o->NotifyDecodingStart(smeta); + Hypergraph forest; // -LM forest + translator->ProcessMarkupHints(smeta.sgml_); + Timer t("Translation"); + const bool translation_successful = + translator->Translate(to_translate, &smeta, *init_weights, &forest); + translator->SentenceComplete(); + + if (!translation_successful) { + if (!SILENT) { cerr << " NO PARSE FOUND.\n"; } + o->NotifySourceParseFailure(smeta); + o->NotifyDecodingComplete(smeta); + if (conf.count("show_conditional_prob")) { + cout << "-Inf" << endl << flush; + } else if (!SILENT) { + cout << endl; + } + return false; + } + + // this is mainly used for debugging, eventually this will be an assertion + if (!forest.AreNodesUniquelyIdentified()) { + if (!SILENT) cerr << " *** NODES NOT UNIQUELY IDENTIFIED ***\n"; + } + + if (!forest.ArePreGoalEdgesArity1()) { + cerr << "Pre-goal edges are not arity-1. The decoder requires this.\n"; + abort(); + } + + const bool show_tree_structure=conf.count("show_tree_structure"); + if (!SILENT) forest_stats(forest," Init. forest",show_tree_structure,oracle.show_derivation); + if (conf.count("show_expected_length")) { + const PRPair res = + Inside, + PRWeightFunction >(forest); + cerr << " Expected length (words): " << (res.r / res.p).as_float() << "\t" << res << endl; + } + + if (conf.count("show_partition")) { + const prob_t z = Inside(forest); + cerr << " Partition log(Z): " << log(z) << endl; + } + + SummaryFeature summary_feature_type = kNODE_RISK; + if (conf["summary_feature_type"].as() == "edge_risk") + summary_feature_type = kEDGE_RISK; + else if (conf["summary_feature_type"].as() == "node_risk") + summary_feature_type = kNODE_RISK; + else if (conf["summary_feature_type"].as() == "edge_prob") + summary_feature_type = kEDGE_PROB; + else { + cerr << "Bad summary_feature_type: " << conf["summary_feature_type"].as() << endl; + abort(); + } + + if (conf.count("show_target_graph")) { + HypergraphIO::WriteTarget(conf["show_target_graph"].as(), sent_id, forest); + } + if (conf.count("incremental_search")) { + incremental->Search(conf["cubepruning_pop_limit"].as(), forest); + } + if (conf.count("show_target_graph") || conf.count("incremental_search")) { + o->NotifyDecodingComplete(smeta); + return true; + } + + for (int pass = 0; pass < rescoring_passes.size(); ++pass) { + const RescoringPass& rp = rescoring_passes[pass]; + const vector& cur_weights = *rp.weight_vector; + if (!SILENT) cerr << endl << " RESCORING PASS #" << (pass+1) << " " << rp << endl; + + string passtr = "Pass1"; passtr[4] += pass; + forest.Reweight(cur_weights); + const bool has_rescoring_models = !rp.models->empty(); + if (has_rescoring_models) { + Timer t("Forest rescoring:"); + rp.models->PrepareForInput(smeta); + Hypergraph rescored_forest; +#ifdef CP_TIME + CpTime::Sub(clock()); +#endif + ApplyModelSet(forest, + smeta, + *rp.models, + *rp.inter_conf, + 
&rescored_forest); +#ifdef CP_TIME + CpTime::Add(clock()); +#endif + forest.swap(rescored_forest); + forest.Reweight(cur_weights); + if (!SILENT) forest_stats(forest," " + passtr +" forest",show_tree_structure,oracle.show_derivation, conf.count("extract_rules"), extract_file); + // this is mainly used for debugging, eventually this will be an assertion + if (!forest.AreNodesUniquelyIdentified()) { + if (!SILENT) cerr << " *** NODES NOT UNIQUELY IDENTIFIED ***\n"; + } + } + + if (conf.count("show_partition")) { + const prob_t z = Inside(forest); + cerr << " " << passtr << " partition log(Z): " << log(z) << endl; + } + + if (rp.fid_summary) { + if (summary_feature_type == kEDGE_PROB) { + const prob_t z = forest.PushWeightsToGoal(1.0); + if (!std::isfinite(log(z)) || std::isnan(log(z))) { + cerr << " " << passtr << " !!! Invalid partition detected, abandoning.\n"; + } else { + for (int i = 0; i < forest.edges_.size(); ++i) { + const double log_prob_transition = log(forest.edges_[i].edge_prob_); // locally normalized by the edge + // head node by forest.PushWeightsToGoal + if (!std::isfinite(log_prob_transition) || std::isnan(log_prob_transition)) { + cerr << "Edge: i=" << i << " got bad inside prob: " << *forest.edges_[i].rule_ << endl; + abort(); + } + + forest.edges_[i].feature_values_.set_value(rp.fid_summary, log_prob_transition); + } + forest.Reweight(cur_weights); // reset weights + } + } else if (summary_feature_type == kNODE_RISK) { + Hypergraph::EdgeProbs posts; + const prob_t z = forest.ComputeEdgePosteriors(1.0, &posts); + if (!std::isfinite(log(z)) || std::isnan(log(z))) { + cerr << " " << passtr << " !!! Invalid partition detected, abandoning.\n"; + } else { + for (int i = 0; i < forest.nodes_.size(); ++i) { + const Hypergraph::EdgesVector& in_edges = forest.nodes_[i].in_edges_; + prob_t node_post = prob_t(0); + for (int j = 0; j < in_edges.size(); ++j) + node_post += (posts[in_edges[j]] / z); + const double log_np = log(node_post); + if (!std::isfinite(log_np) || std::isnan(log_np)) { + cerr << "got bad posterior prob for node " << i << endl; + abort(); + } + for (int j = 0; j < in_edges.size(); ++j) + forest.edges_[in_edges[j]].feature_values_.set_value(rp.fid_summary, exp(log_np)); +// Hypergraph::Edge& example_edge = forest.edges_[in_edges[0]]; +// string n = "NONE"; +// if (forest.nodes_[i].cat_) n = TD::Convert(-forest.nodes_[i].cat_); +// cerr << "[" << n << "," << example_edge.i_ << "," << example_edge.j_ << "] = " << exp(log_np) << endl; + } + } + } else if (summary_feature_type == kEDGE_RISK) { + Hypergraph::EdgeProbs posts; + const prob_t z = forest.ComputeEdgePosteriors(1.0, &posts); + if (!std::isfinite(log(z)) || std::isnan(log(z))) { + cerr << " " << passtr << " !!! Invalid partition detected, abandoning.\n"; + } else { + assert(posts.size() == forest.edges_.size()); + for (int i = 0; i < posts.size(); ++i) { + const double log_np = log(posts[i] / z); + if (!std::isfinite(log_np) || std::isnan(log_np)) { + cerr << "got bad posterior prob for node " << i << endl; + abort(); + } + forest.edges_[i].feature_values_.set_value(rp.fid_summary, exp(log_np)); + } + } + } else { + assert(!"shouldn't happen"); + } + } + + string fullbp = "beam_prune" + StringSuffixForRescoringPass(pass); + string fulldp = "density_prune" + StringSuffixForRescoringPass(pass); + maybe_prune(forest,conf,fullbp.c_str(),fulldp.c_str(),passtr,srclen); + } + + const vector& last_weights = (rescoring_passes.empty() ? 
*init_weights : *rescoring_passes.back().weight_vector); + + // Oracle Rescoring + if(get_oracle_forest) { + assert(!"this is broken"); SparseVector dummy; // = last_weights + Oracle oc=oracle.ComputeOracle(smeta,&forest,dummy,10,conf["forest_output"].as()); + if (!SILENT) cerr << " +Oracle BLEU forest (nodes/edges): " << forest.nodes_.size() << '/' << forest.edges_.size() << endl; + if (!SILENT) cerr << " +Oracle BLEU (paths): " << forest.NumberOfPaths() << endl; + oc.hope.Print(cerr," +Oracle BLEU"); + oc.fear.Print(cerr," -Oracle BLEU"); + //Add 1-best translation (trans) to psuedo-doc vectors + if (!SILENT) oracle.IncludeLastScore(&cerr); + } + o->NotifyTranslationForest(smeta, &forest); + + // TODO I think this should probably be handled by an Observer + if (conf.count("forest_output") && !has_ref) { + ForestWriter writer(str("forest_output",conf), sent_id); + if (FileExists(writer.fname_)) { + if (!SILENT) cerr << " Unioning...\n"; + Hypergraph new_hg; + { + ReadFile rf(writer.fname_); + bool succeeded = HypergraphIO::ReadFromBinary(rf.stream(), &new_hg); + if (!succeeded) abort(); + } + HG::Union(forest, &new_hg); + bool succeeded = writer.Write(new_hg); + if (!succeeded) abort(); + } else { + bool succeeded = writer.Write(forest); + if (!succeeded) abort(); + } + } + + // TODO I think this should probably be handled by an Observer + if (sample_max_trans) { + MaxTranslationSample(&forest, sample_max_trans, conf.count("k_best") ? conf["k_best"].as() : 0); + } else { + if (kbest && !has_ref) { + //TODO: does this work properly? + const string deriv_fname = conf.count("show_derivations") ? str("show_derivations",conf) : "-"; + oracle.DumpKBest(sent_id, forest, conf["k_best"].as(), unique_kbest,mr_mira_compat, smeta.GetSourceLength(), "-", deriv_fname); + } else if (csplit_output_plf) { + cout << HypergraphIO::AsPLF(forest, false) << endl; + } else { + if (!graphviz && !has_ref && !joshua_viz && !SILENT) { + vector trans; + ViterbiESentence(forest, &trans); + cout << TD::GetString(trans) << endl << flush; + } + if (joshua_viz) { + cout << sent_id << " ||| " << JoshuaVisualizationString(forest) << " ||| 1.0 ||| " << -1.0 << endl << flush; + } + } + } + + prob_t first_z; + if (conf.count("show_conditional_prob")) { + first_z = Inside(forest); + } + + // TODO this should be handled by an Observer + const int max_trans_beam_size = conf.count("max_translation_beam") ? + conf["max_translation_beam"].as() : 0; + if (max_trans_beam_size) { + Hack::MaxTrans(forest, max_trans_beam_size); + return true; + } + + // TODO this should be handled by an Observer + if (graphviz && !has_ref) forest.PrintGraphviz(); + + // the following are only used if write_gradient is true! 
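+ // full_exp holds model feature expectations over the unconstrained forest and
+ // ref_exp the expectations over the reference-constrained forest; the accumulated
+ // gradient piece is (ref_exp - full_exp) and the objective piece is (log_z - log_ref_z).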
+ SparseVector full_exp, ref_exp, gradient; + double log_z = 0, log_ref_z = 0; + if (write_gradient) { + const prob_t z = InsideOutside, EdgeFeaturesAndProbWeightFunction>(forest, &full_exp); + log_z = log(z); + full_exp /= z; + } + if (conf.count("show_cfg_search_space")) + HypergraphIO::WriteAsCFG(forest); + if (has_ref) { + if (HG::Intersect(ref, &forest)) { +// if (crf_uniform_empirical) { +// if (!SILENT) cerr << " USING UNIFORM WEIGHTS\n"; +// for (int i = 0; i < forest.edges_.size(); ++i) +// forest.edges_[i].edge_prob_=prob_t::One(); } + if (remove_intersected_rule_annotations) { + for (unsigned i = 0; i < forest.edges_.size(); ++i) + if (forest.edges_[i].rule_ && + forest.edges_[i].rule_->parent_rule_) + forest.edges_[i].rule_ = forest.edges_[i].rule_->parent_rule_; + } + forest.Reweight(last_weights); + // this is mainly used for debugging, eventually this will be an assertion + if (!forest.AreNodesUniquelyIdentified()) { + if (!SILENT) cerr << " *** NODES NOT UNIQUELY IDENTIFIED ***\n"; + } + if (!SILENT) forest_stats(forest," Constr. forest",show_tree_structure,oracle.show_derivation); + if (!SILENT) cerr << " Constr. VitTree: " << ViterbiFTree(forest) << endl; + if (conf.count("show_partition")) { + const prob_t z = Inside(forest); + cerr << " Contst. partition log(Z): " << log(z) << endl; + } + o->NotifyAlignmentForest(smeta, &forest); + if (conf.count("show_cfg_alignment_space")) + HypergraphIO::WriteAsCFG(forest); + if (conf.count("forest_output")) { + ForestWriter writer(str("forest_output",conf), sent_id); + if (FileExists(writer.fname_)) { + if (!SILENT) cerr << " Unioning...\n"; + Hypergraph new_hg; + { + ReadFile rf(writer.fname_); + bool succeeded = HypergraphIO::ReadFromBinary(rf.stream(), &new_hg); + if (!succeeded) abort(); + } + HG::Union(forest, &new_hg); + bool succeeded = writer.Write(new_hg); + if (!succeeded) abort(); + } else { + bool succeeded = writer.Write(forest); + if (!succeeded) abort(); + } + } + if (aligner_mode && !output_training_vector) + AlignerTools::WriteAlignment(smeta.GetSourceLattice(), smeta.GetReference(), forest, &cout, 0 == conf.count("aligner_use_viterbi"), kbest ? conf["k_best"].as() : 0); + if (write_gradient) { + const prob_t ref_z = InsideOutside, EdgeFeaturesAndProbWeightFunction>(forest, &ref_exp); + ref_exp /= ref_z; +// if (crf_uniform_empirical) +// log_ref_z = ref_exp.dot(last_weights); + log_ref_z = log(ref_z); + //cerr << " MODEL LOG Z: " << log_z << endl; + //cerr << " EMPIRICAL LOG Z: " << log_ref_z << endl; + if ((log_z - log_ref_z) < kMINUS_EPSILON) { + cerr << "DIFF. ERR! log_z < log_ref_z: " << log_z << " " << log_ref_z << endl; + exit(1); + } + assert(!std::isnan(log_ref_z)); + ref_exp -= full_exp; + acc_vec += ref_exp; + acc_obj += (log_z - log_ref_z); + } + if (feature_expectations) { + const prob_t z = + InsideOutside, EdgeFeaturesAndProbWeightFunction>(forest, &ref_exp); + ref_exp /= z; + acc_obj += log(z); + acc_vec += ref_exp; + } + + if (output_training_vector) { + acc_vec.erase(0); + ++g_count; + if (g_count % combine_size == 0) { + if (encode_b64) { + cout << "0\t"; + SparseVector dav; ConvertSV(acc_vec, &dav); + B64::Encode(acc_obj, dav, &cout); + cout << endl << flush; + } else { + cout << "0\t**OBJ**=" << acc_obj << ';' << acc_vec << endl << flush; + } + acc_vec.clear(); + acc_obj = 0; + } + } + if (conf.count("graphviz")) forest.PrintGraphviz(); + if (kbest) { + const string deriv_fname = conf.count("show_derivations") ? 
str("show_derivations",conf) : "-"; + oracle.DumpKBest(sent_id, forest, conf["k_best"].as(), unique_kbest, mr_mira_compat, smeta.GetSourceLength(), "-", deriv_fname); + } + if (conf.count("show_conditional_prob")) { + const prob_t ref_z = Inside(forest); + cout << (log(ref_z) - log(first_z)) << endl << flush; + } + } else { + o->NotifyAlignmentFailure(smeta); + if (!SILENT) cerr << " REFERENCE UNREACHABLE.\n"; + if (write_gradient) { + cout << endl << flush; + } + if (conf.count("show_conditional_prob")) { + cout << "-Inf" << endl << flush; + } + } + } + o->NotifyDecodingComplete(smeta); + return true; +} diff --git a/decoder/decoder.h b/decoder/decoder.h new file mode 100644 index 000000000..6250d1eb3 --- /dev/null +++ b/decoder/decoder.h @@ -0,0 +1,69 @@ +#ifndef DECODER_H_ +#define DECODER_H_ + +#include +#include +#include +#include +#include + +#include "weights.h" // weight_t + +#undef CP_TIME +//#define CP_TIME +#ifdef CP_TIME +#include +struct CpTime{ +public: + static void Add(clock_t x); + static void Sub(clock_t x); + static double Get(); +private: + static clock_t time_; +}; +#endif + +class SentenceMetadata; +class Hypergraph; +struct DecoderImpl; + +class DecoderObserver { + public: + virtual ~DecoderObserver(); + virtual void NotifyDecodingStart(const SentenceMetadata& smeta); + virtual void NotifySourceParseFailure(const SentenceMetadata& smeta); + virtual void NotifyTranslationForest(const SentenceMetadata& smeta, Hypergraph* hg); + virtual void NotifyAlignmentFailure(const SentenceMetadata& semta); + virtual void NotifyAlignmentForest(const SentenceMetadata& smeta, Hypergraph* hg); + virtual void NotifyDecodingComplete(const SentenceMetadata& smeta); +}; + +struct Grammar; // TODO once the decoder interface is cleaned up, + // this should be somewhere else +class Decoder { + public: + Decoder(int argc, char** argv); + Decoder(std::istream* config_file); + bool Decode(const std::string& input, DecoderObserver* observer = NULL); + + // access this to either *read* or *write* to the decoder's last + // weight vector (i.e., the weights of the finest past) + std::vector& CurrentWeightVector(); + const std::vector& CurrentWeightVector() const; + + // this sets the current sentence ID + void SetId(int id); + ~Decoder(); + const boost::program_options::variables_map& GetConf() const { return conf; } + + // add grammar rules (currently only supported by SCFG decoders) + // that will be used on subsequent calls to Decode. rules should be in standard + // text format. This function does NOT read from a file. 
+ void AddSupplementalGrammar(boost::shared_ptr gp); + void AddSupplementalGrammarFromString(const std::string& grammar_string); + private: + boost::program_options::variables_map conf; + boost::shared_ptr pimpl_; +}; + +#endif diff --git a/decoder/earley_composer.cc b/decoder/earley_composer.cc new file mode 100644 index 000000000..d47a69699 --- /dev/null +++ b/decoder/earley_composer.cc @@ -0,0 +1,761 @@ +#include "earley_composer.h" + +#include +#include +#include +#include +#ifndef HAVE_OLD_CPP +# include +# include +#else +# include +# include +namespace std { using std::tr1::unordered_map; using std::tr1::unordered_multiset; using std::tr1::unordered_set; } +#endif + +#include +#include +#include +#include "fast_lexical_cast.hpp" + +#include "phrasetable_fst.h" +#include "sparse_vector.h" +#include "tdict.h" +#include "hg.h" +#include "hg_remove_eps.h" + +using namespace std; + +// Define the following macro if you want to see lots of debugging output +// when you run the chart parser +#undef DEBUG_CHART_PARSER + +// A few constants used by the chart parser /////////////// +static const int kMAX_NODES = 2000000; +static const string kPHRASE_STRING = "X"; +static bool constants_need_init = true; +static WordID kUNIQUE_START; +static WordID kPHRASE; +static TRulePtr kX1X2; +static TRulePtr kX1; +static WordID kEPS; +static TRulePtr kEPSRule; + +static void InitializeConstants() { + if (constants_need_init) { + kPHRASE = TD::Convert(kPHRASE_STRING) * -1; + kUNIQUE_START = TD::Convert("S") * -1; + kX1X2.reset(new TRule("[X] ||| [X,1] [X,2] ||| [X,1] [X,2]")); + kX1.reset(new TRule("[X] ||| [X,1] ||| [X,1]")); + kEPSRule.reset(new TRule("[X] ||| ||| ")); + kEPS = TD::Convert(""); + constants_need_init = false; + } +} +//////////////////////////////////////////////////////////// + +TRulePtr CreateBinaryRule(int lhs, int rhs1, int rhs2) { + TRule* r = new TRule(*kX1X2); + r->lhs_ = lhs; + r->f_[0] = rhs1; + r->f_[1] = rhs2; + return TRulePtr(r); +} + +TRulePtr CreateUnaryRule(int lhs, int rhs1) { + TRule* r = new TRule(*kX1); + r->lhs_ = lhs; + r->f_[0] = rhs1; + return TRulePtr(r); +} + +TRulePtr CreateEpsilonRule(int lhs) { + TRule* r = new TRule(*kEPSRule); + r->lhs_ = lhs; + return TRulePtr(r); +} + +class EGrammarNode { + friend bool EarleyComposer::Compose(const Hypergraph& src_forest, Hypergraph* trg_forest); + friend void AddGrammarRule(const string& r, map* g); + public: +#ifdef DEBUG_CHART_PARSER + string hint; +#endif + EGrammarNode() : is_some_rule_complete(false), is_root(false) {} + const map& GetTerminals() const { return tptr; } + const map& GetNonTerminals() const { return ntptr; } + bool HasNonTerminals() const { return (!ntptr.empty()); } + bool HasTerminals() const { return (!tptr.empty()); } + bool RuleCompletes() const { + return (is_some_rule_complete || (ntptr.empty() && tptr.empty())); + } + bool GrammarContinues() const { + return !(ntptr.empty() && tptr.empty()); + } + bool IsRoot() const { + return is_root; + } + // these are the features associated with the rule from the start + // node up to this point. If you use these features, you must + // not Extend() this rule. 
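+ // Traversal sketch: starting from the root node for an LHS, call Extend(sym)
+ // for each RHS symbol in turn; when the resulting node RuleCompletes(), the
+ // features of that completed production can be read with GetCFGProductionFeatures().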
+ const SparseVector& GetCFGProductionFeatures() const { + return input_features; + } + + const EGrammarNode* Extend(const WordID& t) const { + if (t < 0) { + map::const_iterator it = ntptr.find(t); + if (it == ntptr.end()) return NULL; + return &it->second; + } else { + map::const_iterator it = tptr.find(t); + if (it == tptr.end()) return NULL; + return &it->second; + } + } + + private: + map tptr; + map ntptr; + SparseVector input_features; + bool is_some_rule_complete; + bool is_root; +}; +typedef map EGrammar; // indexed by the rule LHS + +// edges are immutable once created +struct Edge { +#ifdef DEBUG_CHART_PARSER + static int id_count; + const int id; +#endif + const WordID cat; // lhs side of rule proved/being proved + const EGrammarNode* const dot; // dot position + const FSTNode* const q; // start of span + const FSTNode* const r; // end of span + const Edge* const active_parent; // back pointer, NULL for PREDICT items + const Edge* const passive_parent; // back pointer, NULL for SCAN and PREDICT items + const TargetPhraseSet* const tps; // translations + boost::shared_ptr > features; // features from CFG rule + + bool IsPassive() const { + // when a rule is completed, this value will be set + return static_cast(features); + } + bool IsActive() const { return !IsPassive(); } + bool IsInitial() const { + return !(active_parent || passive_parent); + } + bool IsCreatedByScan() const { + return active_parent && !passive_parent && !dot->IsRoot(); + } + bool IsCreatedByPredict() const { + return dot->IsRoot(); + } + bool IsCreatedByComplete() const { + return active_parent && passive_parent; + } + + // constructor for PREDICT + Edge(WordID c, const EGrammarNode* d, const FSTNode* q_and_r) : +#ifdef DEBUG_CHART_PARSER + id(++id_count), +#endif + cat(c), dot(d), q(q_and_r), r(q_and_r), active_parent(NULL), passive_parent(NULL), tps(NULL) {} + Edge(WordID c, const EGrammarNode* d, const FSTNode* q_and_r, const Edge* act_parent) : +#ifdef DEBUG_CHART_PARSER + id(++id_count), +#endif + cat(c), dot(d), q(q_and_r), r(q_and_r), active_parent(act_parent), passive_parent(NULL), tps(NULL) {} + + // constructors for SCAN + Edge(WordID c, const EGrammarNode* d, const FSTNode* i, const FSTNode* j, + const Edge* act_par, const TargetPhraseSet* translations) : +#ifdef DEBUG_CHART_PARSER + id(++id_count), +#endif + cat(c), dot(d), q(i), r(j), active_parent(act_par), passive_parent(NULL), tps(translations) {} + + Edge(WordID c, const EGrammarNode* d, const FSTNode* i, const FSTNode* j, + const Edge* act_par, const TargetPhraseSet* translations, + const SparseVector& feats) : +#ifdef DEBUG_CHART_PARSER + id(++id_count), +#endif + cat(c), dot(d), q(i), r(j), active_parent(act_par), passive_parent(NULL), tps(translations), + features(new SparseVector(feats)) {} + + // constructors for COMPLETE + Edge(WordID c, const EGrammarNode* d, const FSTNode* i, const FSTNode* j, + const Edge* act_par, const Edge *pas_par) : +#ifdef DEBUG_CHART_PARSER + id(++id_count), +#endif + cat(c), dot(d), q(i), r(j), active_parent(act_par), passive_parent(pas_par), tps(NULL) { + assert(pas_par->IsPassive()); + assert(act_par->IsActive()); + } + + Edge(WordID c, const EGrammarNode* d, const FSTNode* i, const FSTNode* j, + const Edge* act_par, const Edge *pas_par, const SparseVector& feats) : +#ifdef DEBUG_CHART_PARSER + id(++id_count), +#endif + cat(c), dot(d), q(i), r(j), active_parent(act_par), passive_parent(pas_par), tps(NULL), + features(new SparseVector(feats)) { + assert(pas_par->IsPassive()); + 
assert(act_par->IsActive()); + } + + // constructor for COMPLETE query + Edge(const FSTNode* _r) : +#ifdef DEBUG_CHART_PARSER + id(0), +#endif + cat(0), dot(NULL), q(NULL), + r(_r), active_parent(NULL), passive_parent(NULL), tps(NULL) {} + // constructor for MERGE quere + Edge(const FSTNode* _q, int) : +#ifdef DEBUG_CHART_PARSER + id(0), +#endif + cat(0), dot(NULL), q(_q), + r(NULL), active_parent(NULL), passive_parent(NULL), tps(NULL) {} +}; +#ifdef DEBUG_CHART_PARSER +int Edge::id_count = 0; +#endif + +ostream& operator<<(ostream& os, const Edge& e) { + string type = "PREDICT"; + if (e.IsCreatedByScan()) + type = "SCAN"; + else if (e.IsCreatedByComplete()) + type = "COMPLETE"; + os << "[" +#ifdef DEBUG_CHART_PARSER + << '(' << e.id << ") " +#else + << '(' << &e << ") " +#endif + << "q=" << e.q << ", r=" << e.r + << ", cat="<< TD::Convert(e.cat*-1) << ", dot=" + << e.dot +#ifdef DEBUG_CHART_PARSER + << e.dot->hint +#endif + << (e.IsActive() ? ", Active" : ", Passive") + << ", " << type; +#ifdef DEBUG_CHART_PARSER + if (e.active_parent) { os << ", act.parent=(" << e.active_parent->id << ')'; } + if (e.passive_parent) { os << ", psv.parent=(" << e.passive_parent->id << ')'; } +#endif + if (e.tps) { os << ", tps=" << e.tps; } + return os << ']'; +} + +struct Traversal { + const Edge* const edge; // result from the active / passive combination + const Edge* const active; + const Edge* const passive; + Traversal(const Edge* me, const Edge* a, const Edge* p) : edge(me), active(a), passive(p) {} +}; + +struct UniqueTraversalHash { + size_t operator()(const Traversal* t) const { + size_t x = 5381; + x = ((x << 5) + x) ^ reinterpret_cast(t->active); + x = ((x << 5) + x) ^ reinterpret_cast(t->passive); + x = ((x << 5) + x) ^ t->edge->IsActive(); + return x; + } +}; + +struct UniqueTraversalEquals { + size_t operator()(const Traversal* a, const Traversal* b) const { + return (a->passive == b->passive && a->active == b->active && a->edge->IsActive() == b->edge->IsActive()); + } +}; + +struct UniqueEdgeHash { + size_t operator()(const Edge* e) const { + size_t x = 5381; + if (e->IsActive()) { + x = ((x << 5) + x) ^ reinterpret_cast(e->dot); + x = ((x << 5) + x) ^ reinterpret_cast(e->q); + x = ((x << 5) + x) ^ reinterpret_cast(e->r); + x = ((x << 5) + x) ^ static_cast(e->cat); + x += 13; + } else { // with passive edges, we don't care about the dot + x = ((x << 5) + x) ^ reinterpret_cast(e->q); + x = ((x << 5) + x) ^ reinterpret_cast(e->r); + x = ((x << 5) + x) ^ static_cast(e->cat); + } + return x; + } +}; + +struct UniqueEdgeEquals { + bool operator()(const Edge* a, const Edge* b) const { + if (a->IsActive() != b->IsActive()) return false; + if (a->IsActive()) { + return (a->cat == b->cat) && (a->dot == b->dot) && (a->q == b->q) && (a->r == b->r); + } else { + return (a->cat == b->cat) && (a->q == b->q) && (a->r == b->r); + } + } +}; + +struct REdgeHash { + size_t operator()(const Edge* e) const { + size_t x = 5381; + x = ((x << 5) + x) ^ reinterpret_cast(e->r); + return x; + } +}; + +struct REdgeEquals { + bool operator()(const Edge* a, const Edge* b) const { + return (a->r == b->r); + } +}; + +struct QEdgeHash { + size_t operator()(const Edge* e) const { + size_t x = 5381; + x = ((x << 5) + x) ^ reinterpret_cast(e->q); + return x; + } +}; + +struct QEdgeEquals { + bool operator()(const Edge* a, const Edge* b) const { + return (a->q == b->q); + } +}; + +struct EdgeQueue { + queue q; + EdgeQueue() {} + void clear() { while(!q.empty()) q.pop(); } + bool HasWork() const { return !q.empty(); } + const 
Edge* Next() { const Edge* res = q.front(); q.pop(); return res; } + void AddEdge(const Edge* s) { q.push(s); } +}; + +class EarleyComposerImpl { + public: + EarleyComposerImpl(WordID start_cat, const FSTNode& q_0) : start_cat_(start_cat), q_0_(&q_0) {} + + // returns false if the intersection is empty + bool Compose(const EGrammar& g, Hypergraph* forest) { + goal_node = NULL; + EGrammar::const_iterator sit = g.find(start_cat_); + forest->ReserveNodes(kMAX_NODES); + assert(sit != g.end()); + Edge* init = new Edge(start_cat_, &sit->second, q_0_); + if (!IncorporateNewEdge(init)) { + cerr << "Failed to create initial edge!\n"; + abort(); + } + while (exp_agenda.HasWork() || agenda.HasWork()) { + while(exp_agenda.HasWork()) { + const Edge* edge = exp_agenda.Next(); + FinishEdge(edge, forest); + } + if (agenda.HasWork()) { + const Edge* edge = agenda.Next(); +#ifdef DEBUG_CHART_PARSER + cerr << "processing (" << edge->id << ')' << endl; +#endif + if (edge->IsActive()) { + if (edge->dot->HasTerminals()) + DoScan(edge); + if (edge->dot->HasNonTerminals()) { + DoMergeWithPassives(edge); + DoPredict(edge, g); + } + } else { + DoComplete(edge); + } + } + } + if (goal_node) { + forest->PruneUnreachable(goal_node->id_); + RemoveEpsilons(forest, kEPS); + } + FreeAll(); + return goal_node; + } + + void FreeAll() { + for (int i = 0; i < free_list_.size(); ++i) + delete free_list_[i]; + free_list_.clear(); + for (int i = 0; i < traversal_free_list_.size(); ++i) + delete traversal_free_list_[i]; + traversal_free_list_.clear(); + all_traversals.clear(); + exp_agenda.clear(); + agenda.clear(); + tps2node.clear(); + edge2node.clear(); + all_edges.clear(); + passive_edges.clear(); + active_edges.clear(); + } + + ~EarleyComposerImpl() { + FreeAll(); + } + + // returns the total number of edges created during composition + int EdgesCreated() const { + return free_list_.size(); + } + + private: + void DoScan(const Edge* edge) { + // here, we assume that the FST will potentially have many more outgoing + // edges than the grammar, which will be just a couple. If you want to + // efficiently handle the case where both are relatively large, this code + // will need to change how the intersection is done. The best general + // solution would probably be the Baeza-Yates double binary search. + + const EGrammarNode* dot = edge->dot; + const FSTNode* r = edge->r; + const map& terms = dot->GetTerminals(); + for (map::const_iterator git = terms.begin(); + git != terms.end(); ++git) { + const FSTNode* next_r = r->Extend(git->first); + if (!next_r) continue; + const EGrammarNode* next_dot = &git->second; + const bool grammar_continues = next_dot->GrammarContinues(); + const bool rule_completes = next_dot->RuleCompletes(); + assert(grammar_continues || rule_completes); + const SparseVector& input_features = next_dot->GetCFGProductionFeatures(); + // create up to 4 new edges! + if (next_r->HasOutgoingNonEpsilonEdges()) { // are there further symbols in the FST? 
+ const TargetPhraseSet* translations = NULL; + if (rule_completes) + IncorporateNewEdge(new Edge(edge->cat, next_dot, edge->q, next_r, edge, translations, input_features)); + if (grammar_continues) + IncorporateNewEdge(new Edge(edge->cat, next_dot, edge->q, next_r, edge, translations)); + } + if (next_r->HasData()) { // indicates a loop back to q_0 in the FST + const TargetPhraseSet* translations = next_r->GetTranslations(); + if (rule_completes) + IncorporateNewEdge(new Edge(edge->cat, next_dot, edge->q, q_0_, edge, translations, input_features)); + if (grammar_continues) + IncorporateNewEdge(new Edge(edge->cat, next_dot, edge->q, q_0_, edge, translations)); + } + } + } + + void DoPredict(const Edge* edge, const EGrammar& g) { + const EGrammarNode* dot = edge->dot; + const map& non_terms = dot->GetNonTerminals(); + for (map::const_iterator git = non_terms.begin(); + git != non_terms.end(); ++git) { + const WordID nt_to_predict = git->first; + //cerr << edge->id << " -- " << TD::Convert(nt_to_predict*-1) << endl; + EGrammar::const_iterator egi = g.find(nt_to_predict); + if (egi == g.end()) { + cerr << "[ERROR] Can't find any grammar rules with a LHS of type " + << TD::Convert(-1*nt_to_predict) << '!' << endl; + continue; + } + assert(edge->IsActive()); + const EGrammarNode* new_dot = &egi->second; + Edge* new_edge = new Edge(nt_to_predict, new_dot, edge->r, edge); + IncorporateNewEdge(new_edge); + } + } + + void DoComplete(const Edge* passive) { +#ifdef DEBUG_CHART_PARSER + cerr << " complete: " << *passive << endl; +#endif + const WordID completed_nt = passive->cat; + const FSTNode* q = passive->q; + const FSTNode* next_r = passive->r; + const Edge query(q); + const pair::iterator, + unordered_multiset::iterator > p = + active_edges.equal_range(&query); + for (unordered_multiset::iterator it = p.first; + it != p.second; ++it) { + const Edge* active = *it; +#ifdef DEBUG_CHART_PARSER + cerr << " pos: " << *active << endl; +#endif + const EGrammarNode* next_dot = active->dot->Extend(completed_nt); + if (!next_dot) continue; + const SparseVector& input_features = next_dot->GetCFGProductionFeatures(); + // add up to 2 rules + if (next_dot->RuleCompletes()) + IncorporateNewEdge(new Edge(active->cat, next_dot, active->q, next_r, active, passive, input_features)); + if (next_dot->GrammarContinues()) + IncorporateNewEdge(new Edge(active->cat, next_dot, active->q, next_r, active, passive)); + } + } + + void DoMergeWithPassives(const Edge* active) { + // edge is active, has non-terminals, we need to find the passives that can extend it + assert(active->IsActive()); + assert(active->dot->HasNonTerminals()); +#ifdef DEBUG_CHART_PARSER + cerr << " merge active with passives: ACT=" << *active << endl; +#endif + const Edge query(active->r, 1); + const pair::iterator, + unordered_multiset::iterator > p = + passive_edges.equal_range(&query); + for (unordered_multiset::iterator it = p.first; + it != p.second; ++it) { + const Edge* passive = *it; + const EGrammarNode* next_dot = active->dot->Extend(passive->cat); + if (!next_dot) continue; + const FSTNode* next_r = passive->r; + const SparseVector& input_features = next_dot->GetCFGProductionFeatures(); + if (next_dot->RuleCompletes()) + IncorporateNewEdge(new Edge(active->cat, next_dot, active->q, next_r, active, passive, input_features)); + if (next_dot->GrammarContinues()) + IncorporateNewEdge(new Edge(active->cat, next_dot, active->q, next_r, active, passive)); + } + } + + // take ownership of edge memory, add to various indexes, etc + // returns true if 
this edge is new + bool IncorporateNewEdge(Edge* edge) { + free_list_.push_back(edge); + if (edge->passive_parent && edge->active_parent) { + Traversal* t = new Traversal(edge, edge->active_parent, edge->passive_parent); + traversal_free_list_.push_back(t); + if (all_traversals.find(t) != all_traversals.end()) { + return false; + } else { + all_traversals.insert(t); + } + } + exp_agenda.AddEdge(edge); + return true; + } + + bool FinishEdge(const Edge* edge, Hypergraph* hg) { + bool is_new = false; + if (all_edges.find(edge) == all_edges.end()) { +#ifdef DEBUG_CHART_PARSER + cerr << *edge << " is NEW\n"; +#endif + all_edges.insert(edge); + is_new = true; + if (edge->IsPassive()) passive_edges.insert(edge); + if (edge->IsActive()) active_edges.insert(edge); + agenda.AddEdge(edge); + } else { +#ifdef DEBUG_CHART_PARSER + cerr << *edge << " is NOT NEW.\n"; +#endif + } + AddEdgeToTranslationForest(edge, hg); + return is_new; + } + + // build the translation forest + void AddEdgeToTranslationForest(const Edge* edge, Hypergraph* hg) { + assert(hg->nodes_.size() < kMAX_NODES); + Hypergraph::Node* tps = NULL; + // first add any target language rules + if (edge->tps) { + Hypergraph::Node*& node = tps2node[(size_t)edge->tps]; + if (!node) { + // cerr << "Creating phrases for " << edge->tps << endl; + const vector& rules = edge->tps->GetRules(); + node = hg->AddNode(kPHRASE); + for (int i = 0; i < rules.size(); ++i) { + Hypergraph::Edge* hg_edge = hg->AddEdge(rules[i], Hypergraph::TailNodeVector()); + hg_edge->feature_values_ += rules[i]->GetFeatureValues(); + hg->ConnectEdgeToHeadNode(hg_edge, node); + } + } + tps = node; + } + Hypergraph::Node*& head_node = edge2node[edge]; + if (!head_node) + head_node = hg->AddNode(edge->cat); + if (edge->cat == start_cat_ && edge->q == q_0_ && edge->r == q_0_ && edge->IsPassive()) { + assert(goal_node == NULL || goal_node == head_node); + goal_node = head_node; + } + int rhs1 = 0; + int rhs2 = 0; + Hypergraph::TailNodeVector tail; + SparseVector extra; + if (edge->IsCreatedByPredict()) { + // extra.set_value(FD::Convert("predict"), 1); + } else if (edge->IsCreatedByScan()) { + tail.push_back(edge2node[edge->active_parent]->id_); + rhs1 = edge->active_parent->cat; + if (tps) { + tail.push_back(tps->id_); + rhs2 = kPHRASE; + } + //extra.set_value(FD::Convert("scan"), 1); + } else if (edge->IsCreatedByComplete()) { + tail.push_back(edge2node[edge->active_parent]->id_); + rhs1 = edge->active_parent->cat; + tail.push_back(edge2node[edge->passive_parent]->id_); + rhs2 = edge->passive_parent->cat; + //extra.set_value(FD::Convert("complete"), 1); + } else { + assert(!"unexpected edge type!"); + } + //cerr << head_node->id_ << "<--" << *edge << endl; + +#ifdef DEBUG_CHART_PARSER + for (int i = 0; i < tail.size(); ++i) + if (tail[i] == head_node->id_) { + cerr << "ERROR: " << *edge << "\n i=" << i << endl; + if (i == 1) { cerr << "\tP: " << *edge->passive_parent << endl; } + if (i == 0) { cerr << "\tA: " << *edge->active_parent << endl; } + assert(!"self-loop found!"); + } +#endif + Hypergraph::Edge* hg_edge = NULL; + if (tail.size() == 0) { + hg_edge = hg->AddEdge(CreateEpsilonRule(edge->cat), tail); + } else if (tail.size() == 1) { + hg_edge = hg->AddEdge(CreateUnaryRule(edge->cat, rhs1), tail); + } else if (tail.size() == 2) { + hg_edge = hg->AddEdge(CreateBinaryRule(edge->cat, rhs1, rhs2), tail); + } + if (edge->features) + hg_edge->feature_values_ += *edge->features; + hg_edge->feature_values_ += extra; + hg->ConnectEdgeToHeadNode(hg_edge, head_node); + } + + 
Hypergraph::Node* goal_node; + EdgeQueue exp_agenda; + EdgeQueue agenda; + unordered_map tps2node; + unordered_map edge2node; + unordered_set all_traversals; + unordered_set all_edges; + unordered_multiset passive_edges; + unordered_multiset active_edges; + vector free_list_; + vector traversal_free_list_; + const WordID start_cat_; + const FSTNode* const q_0_; +}; + +#ifdef DEBUG_CHART_PARSER +static string TrimRule(const string& r) { + size_t start = r.find(" |||") + 5; + size_t end = r.rfind(" |||"); + return r.substr(start, end - start); +} +#endif + +void AddGrammarRule(const string& r, EGrammar* g) { + const size_t pos = r.find(" ||| "); + if (pos == string::npos || r[0] != '[') { + cerr << "Bad rule: " << r << endl; + return; + } + const size_t rpos = r.rfind(" ||| "); + string feats; + string rs = r; + if (rpos != pos) { + feats = r.substr(rpos + 5); + rs = r.substr(0, rpos); + } + string rhs = rs.substr(pos + 5); + string trule = rs + " ||| " + rhs + " ||| " + feats; + TRule tr(trule); +#ifdef DEBUG_CHART_PARSER + string hint_last_rule; +#endif + EGrammarNode* cur = &(*g)[tr.GetLHS()]; + cur->is_root = true; + for (int i = 0; i < tr.FLength(); ++i) { + WordID sym = tr.f()[i]; +#ifdef DEBUG_CHART_PARSER + hint_last_rule = TD::Convert(sym < 0 ? -sym : sym); + cur->hint += " <@@> (*" + hint_last_rule + ") " + TrimRule(tr.AsString()); +#endif + if (sym < 0) + cur = &cur->ntptr[sym]; + else + cur = &cur->tptr[sym]; + } +#ifdef DEBUG_CHART_PARSER + cur->hint += " <@@> (" + hint_last_rule + "*) " + TrimRule(tr.AsString()); +#endif + cur->is_some_rule_complete = true; + cur->input_features = tr.GetFeatureValues(); +} + +EarleyComposer::~EarleyComposer() { + delete pimpl_; +} + +EarleyComposer::EarleyComposer(const FSTNode* fst) { + InitializeConstants(); + pimpl_ = new EarleyComposerImpl(kUNIQUE_START, *fst); +} + +bool EarleyComposer::Compose(const Hypergraph& src_forest, Hypergraph* trg_forest) { + // first, convert the src forest into an EGrammar + EGrammar g; + const int nedges = src_forest.edges_.size(); + const int nnodes = src_forest.nodes_.size(); + vector cats(nnodes); + bool assign_cats = false; + for (int i = 0; i < nnodes; ++i) + if (assign_cats) { + cats[i] = TD::Convert("CAT_" + boost::lexical_cast(i)) * -1; + } else { + cats[i] = src_forest.nodes_[i].cat_; + } + // construct the grammar + for (int i = 0; i < nedges; ++i) { + const Hypergraph::Edge& edge = src_forest.edges_[i]; + const vector& src = edge.rule_->f(); + EGrammarNode* cur = &g[cats[edge.head_node_]]; + cur->is_root = true; + int ntc = 0; + for (int j = 0; j < src.size(); ++j) { + WordID sym = src[j]; + if (sym <= 0) { + sym = cats[edge.tail_nodes_[ntc]]; + ++ntc; + cur = &cur->ntptr[sym]; + } else { + cur = &cur->tptr[sym]; + } + } + cur->is_some_rule_complete = true; + cur->input_features = edge.feature_values_; + } + EGrammarNode& goal_rule = g[kUNIQUE_START]; + assert((goal_rule.ntptr.size() == 1 && goal_rule.tptr.size() == 0) || + (goal_rule.ntptr.size() == 0 && goal_rule.tptr.size() == 1)); + + return pimpl_->Compose(g, trg_forest); +} + +bool EarleyComposer::Compose(istream* in, Hypergraph* trg_forest) { + EGrammar g; + while(*in) { + string line; + getline(*in, line); + if (line.empty()) continue; + AddGrammarRule(line, &g); + } + + return pimpl_->Compose(g, trg_forest); +} diff --git a/decoder/earley_composer.h b/decoder/earley_composer.h new file mode 100644 index 000000000..31602f675 --- /dev/null +++ b/decoder/earley_composer.h @@ -0,0 +1,29 @@ +#ifndef EARLEY_COMPOSER_H_ +#define EARLEY_COMPOSER_H_ 
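+
+ // Composes a source CFG (given either as a hypergraph or as a text grammar)
+ // with the phrase-table FST rooted at the given node, producing a target-side
+ // translation hypergraph.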
+ +#include + +class EarleyComposerImpl; +class FSTNode; +class Hypergraph; + +class EarleyComposer { + public: + ~EarleyComposer(); + EarleyComposer(const FSTNode* phrasetable_root); + bool Compose(const Hypergraph& src_forest, Hypergraph* trg_forest); + + // reads the grammar from a file. There must be a single top-level + // S -> X rule. Anything else is possible. Format is: + // [S] ||| [SS,1] + // [SS] ||| [NP,1] [VP,2] ||| Feature1=0.2 Feature2=-2.3 + // [SS] ||| [VP,1] [NP,2] ||| Feature1=0.8 + // [NP] ||| [DET,1] [N,2] ||| Feature3=2 + // ... + bool Compose(std::istream* grammar_file, Hypergraph* trg_forest); + + private: + EarleyComposerImpl* pimpl_; +}; + +#endif diff --git a/decoder/factored_lexicon_helper.cc b/decoder/factored_lexicon_helper.cc new file mode 100644 index 000000000..e78992156 --- /dev/null +++ b/decoder/factored_lexicon_helper.cc @@ -0,0 +1,81 @@ +#include "factored_lexicon_helper.h" + +#include "filelib.h" +#include "stringlib.h" +#include "sentence_metadata.h" + +using namespace std; + +FactoredLexiconHelper::FactoredLexiconHelper() : + kNULL(TD::Convert("")), + has_src_(false), + has_trg_(false) { InitEscape(); } + +FactoredLexiconHelper::FactoredLexiconHelper(const std::string& srcfile, const std::string& trgmapfile) : + kNULL(TD::Convert("")), + has_src_(false), + has_trg_(false) { + if (srcfile.size() && srcfile != "*") { + ReadFile rf(srcfile); + has_src_ = true; + istream& in = *rf.stream(); + string line; + while(in) { + getline(in, line); + if (!in) continue; + vector v; + TD::ConvertSentence(line, &v); + src_.push_back(v); + } + } + if (trgmapfile.size() && trgmapfile != "*") { + ReadFile rf(trgmapfile); + has_trg_ = true; + istream& in = *rf.stream(); + string line; + vector v; + while(in) { + getline(in, line); + if (!in) continue; + SplitOnWhitespace(line, &v); + if (v.size() != 2) { + cerr << "Error reading line in map file: " << line << endl; + abort(); + } + WordID& to = trgmap_[TD::Convert(v[0])]; + if (to != 0) { + cerr << "Duplicate entry for word " << v[0] << endl; + abort(); + } + to = TD::Convert(v[1]); + } + } + InitEscape(); +} + +void FactoredLexiconHelper::InitEscape() { + escape_[TD::Convert("=")] = TD::Convert("__EQ"); + escape_[TD::Convert(";")] = TD::Convert("__SC"); + escape_[TD::Convert(",")] = TD::Convert("__CO"); +} + +void FactoredLexiconHelper::PrepareForInput(const SentenceMetadata& smeta) { + if (has_src_) { + const int id = smeta.GetSentenceID(); + assert(id < src_.size()); + cur_src_ = src_[id]; + } else { + cur_src_.resize(smeta.GetSourceLength()); + for (int i = 0; i < cur_src_.size(); ++i) { + const vector& arcs = smeta.GetSourceLattice()[i]; + assert(arcs.size() == 1); // only sentences supported for now + cur_src_[i] = arcs[0].label; + } + } + if (cur_src_.size() != smeta.GetSourceLength()) { + cerr << "Length mismatch between mapped source and real source in sentence id=" << smeta.GetSentenceID() << endl; + cerr << " mapped len=" << cur_src_.size() << endl; + cerr << " actual len=" << smeta.GetSourceLength() << endl; + } +} + diff --git a/decoder/factored_lexicon_helper.h b/decoder/factored_lexicon_helper.h new file mode 100644 index 000000000..8e89f4736 --- /dev/null +++ b/decoder/factored_lexicon_helper.h @@ -0,0 +1,67 @@ +#ifndef FACTORED_LEXICON_HELPER_ +#define FACTORED_LEXICON_HELPER_ + +#include +#include +#include +#include +#include "tdict.h" + +class SentenceMetadata; + +// when computing features, it can be advantageous to: +// 1) back off to less specific forms (e.g., less highly inflected forms, POS 
tags, etc) +// 2) look at more specific forms (on the source ONLY) +// this class helps you do both by creating a "corpus" view +// should probably add a discussion of why the source can be "refined" by this class +// but not the target. basically, this is because the source is on the right side of +// the conditioning line in the model, and the target is on the left. the most specific +// form must always be generated, but the "source" can include arbitrarily large +// context. +// this currently only works for sentence input to maintain simplicity of the code and +// file formats, but there is no reason why it couldn't work with lattices / CFGs +class FactoredLexiconHelper { + public: + // default constructor does no mapping + FactoredLexiconHelper(); + // Either filename can be empty or * to indicate no mapping + FactoredLexiconHelper(const std::string& srcfile, const std::string& trgmapfile); + + void PrepareForInput(const SentenceMetadata& smeta); + + inline WordID SourceWordAtPosition(const int i) const { + if (i < 0) return kNULL; + assert(i < cur_src_.size()); + return Escape(cur_src_[i]); + } + + inline WordID CoarsenedTargetWordForTarget(const WordID surface_target) const { + if (has_trg_) { + const WordWordMap::const_iterator it = trgmap_.find(surface_target); + if (it == trgmap_.end()) return surface_target; + return Escape(it->second); + } else { + return Escape(surface_target); + } + } + + private: + inline WordID Escape(WordID word) const { + const std::map::const_iterator it = escape_.find(word); + if (it == escape_.end()) return word; + return it->second; + } + + void InitEscape(); + + const WordID kNULL; + bool has_src_; + bool has_trg_; + std::vector > src_; + typedef std::map WordWordMap; + WordWordMap trgmap_; + std::vector cur_src_; + std::map escape_; +}; + +#endif diff --git a/decoder/ff.cc b/decoder/ff.cc new file mode 100644 index 000000000..a6a035b5d --- /dev/null +++ b/decoder/ff.cc @@ -0,0 +1,38 @@ +#include "ff.h" + +#include "tdict.h" +#include "hg.h" + +using namespace std; + +FeatureFunction::~FeatureFunction() {} + +void FeatureFunction::PrepareForInput(const SentenceMetadata&) {} + +void FeatureFunction::FinalTraversalFeatures(const void* /* ant_state */, + SparseVector* /* features */) const {} + +string FeatureFunction::usage_helper(std::string const& name,std::string const& params,std::string const& details,bool sp,bool sd) { + string r=name; + if (sp) { + r+=": "; + r+=params; + } + if (sd) { + r+="\n"; + r+=details; + } + return r; +} + +void FeatureFunction::TraversalFeaturesImpl(const SentenceMetadata&, + const Hypergraph::Edge&, + const std::vector&, + SparseVector*, + SparseVector*, + void*) const { + cerr << "TraversalFeaturesImpl not implemented - override it or TraversalFeaturesLog\n"; + abort(); +} + + diff --git a/decoder/ff.h b/decoder/ff.h new file mode 100644 index 000000000..d6487d970 --- /dev/null +++ b/decoder/ff.h @@ -0,0 +1,100 @@ +#ifndef FF_H_ +#define FF_H_ + +#include +#include +#include "sparse_vector.h" + +namespace HG { struct Edge; struct Node; } +class Hypergraph; +class SentenceMetadata; + +// if you want to develop a new feature, inherit from this class and +// override TraversalFeaturesImpl(...). If it's a feature that returns / +// depends on context, you may also need to implement +// FinalTraversalFeatures(...) 
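As an illustration of the override pattern described in the comment above, a stateless feature that fires once per rule application might look like the sketch below. The RuleIndicator name and its feature string are hypothetical; the sketch assumes the usual cdec headers (ff.h, fdict.h) are available, and the FeatureFunction base class it derives from is declared next.

    // Hypothetical stateless feature: fires 1.0 on every edge.
    class RuleIndicator : public FeatureFunction {
     public:
      explicit RuleIndicator(const std::string& /*param*/)
          : fid_(FD::Convert("RuleIndicator")) {}
      static std::string usage(bool p, bool d) {
        return usage_helper("RuleIndicator", "",
                            "fires on every rule application", p, d);
      }
     protected:
      void TraversalFeaturesImpl(const SentenceMetadata& /*smeta*/,
                                 const HG::Edge& /*edge*/,
                                 const std::vector<const void*>& /*ant_contexts*/,
                                 SparseVector<double>* features,
                                 SparseVector<double>* /*estimated_features*/,
                                 void* /*context*/) const override {
        features->set_value(fid_, 1.0);  // stateless: never touches context bytes
      }
     private:
      const int fid_;
    };
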
+class FeatureFunction { + friend class ExternalFeature; + public: + std::string name_; // set by FF factory using usage() + FeatureFunction() : state_size_(), ignored_state_size_() {} + explicit FeatureFunction(int state_size, int ignored_state_size = 0) + : state_size_(state_size), ignored_state_size_(ignored_state_size) {} + virtual ~FeatureFunction(); + bool IsStateful() const { return state_size_ > 0; } + int StateSize() const { return state_size_; } + // Returns the number of bytes in the state that should be ignored during + // search. When non-zero, the last N bytes in the state should be ignored when + // splitting a hypernode by the state. This allows the feature function to + // store some side data and later retrieve it via the state bytes. + // + // In general, this should not be necessary and it should always be possible + // to replace this with a more appropriate design of state (if you find + // yourself having to ignore some part of the state, you are most likely + // storing redundant information in the state). Be sure that you + // understand how this affects ApplyModelSet() before using it. + int IgnoredStateSize() const { return ignored_state_size_; } + + // override this. not virtual because we want to expose this to factory template for help before creating a FF + static std::string usage(bool show_params,bool show_details) { + return usage_helper("FIXME_feature_needs_name","[no parameters]","[no documentation yet]",show_params,show_details); + } + static std::string usage_helper(std::string const& name,std::string const& params,std::string const& details,bool show_params,bool show_details); + + // called once, per input, before any feature calls to TraversalFeatures, etc. + // used to initialize sentence-specific data structures + virtual void PrepareForInput(const SentenceMetadata& smeta); + + // Compute the feature values and (if this applies) the estimates of the + // feature values when this edge is used incorporated into a larger context + inline void TraversalFeatures(const SentenceMetadata& smeta, + const HG::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* out_state) const { + TraversalFeaturesImpl(smeta, edge, ant_contexts, + features, estimated_features, out_state); + // TODO it's easy for careless feature function developers to overwrite + // the end of their state and clobber someone else's memory. These bugs + // will be horrendously painful to track down. There should be some + // optional strict mode that's enforced here that adds some kind of + // barrier between the blocks reserved for the residual contexts + } + + // if there's some state left when you transition to the goal state, score + // it here. For example, a language model might the cost of adding + // and . + virtual void FinalTraversalFeatures(const void* residual_state, + SparseVector* final_features) const; + + protected: + // context is a pointer to a buffer of size NumBytesContext() that the + // feature function can write its state to. It's up to the feature function + // to determine how much space it needs and to determine how to encode its + // residual contextual information since it is OPAQUE to all clients outside + // of the particular FeatureFunction class. There is one exception: + // equality of the contents (i.e., memcmp) is required to determine whether + // two states can be combined. 
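The stateful counterpart of that pattern writes its recombination state through the opaque context pointer just described and scores any leftover state at the goal. Below is a minimal sketch, again assuming the usual cdec headers (ff.h, hg.h, fdict.h); the LastWordState name and its feature are hypothetical, and the state layout here is just a single WordID.

    // Hypothetical stateful feature: its state is the last target word of a
    // hypothesis, and it fires once at the goal via FinalTraversalFeatures.
    class LastWordState : public FeatureFunction {
     public:
      explicit LastWordState(const std::string& /*param*/)
          : FeatureFunction(sizeof(WordID)),       // state size in bytes
            fid_(FD::Convert("EndsInKnownWord")) {}
      void FinalTraversalFeatures(const void* residual_state,
                                  SparseVector<double>* final_features) const override {
        const WordID last = *static_cast<const WordID*>(residual_state);
        if (last > 0) final_features->set_value(fid_, 1.0);
      }
     protected:
      void TraversalFeaturesImpl(const SentenceMetadata& /*smeta*/,
                                 const HG::Edge& edge,
                                 const std::vector<const void*>& ant_contexts,
                                 SparseVector<double>* /*features*/,
                                 SparseVector<double>* /*estimated_features*/,
                                 void* context) const override {
        WordID last = 0;
        const std::vector<WordID>& e = edge.rule_->e();
        for (unsigned i = 0; i < e.size(); ++i) {
          if (e[i] < 1)   // nonterminal: read the antecedent's stored state
            last = *static_cast<const WordID*>(ant_contexts[-e[i]]);
          else            // terminal
            last = e[i];
        }
        // Write this hypothesis' state; hypotheses whose state bytes compare
        // equal (memcmp) can be recombined during search.
        *static_cast<WordID*>(context) = last;
      }
     private:
      const int fid_;
    };
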
+ virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const HG::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const; + + // !!! ONLY call these from subclass *CONSTRUCTORS* !!! + void SetStateSize(size_t state_size) { + state_size_ = state_size; + } + + // See document of IgnoredStateSize() above. + void SetIgnoredStateSize(size_t ignored_state_size) { + ignored_state_size_ = ignored_state_size; + } + + private: + int state_size_, ignored_state_size_; +}; + +#endif diff --git a/decoder/ff_basic.cc b/decoder/ff_basic.cc new file mode 100644 index 000000000..f960418aa --- /dev/null +++ b/decoder/ff_basic.cc @@ -0,0 +1,79 @@ +#include "ff_basic.h" + +#include "fast_lexical_cast.hpp" +#include "hg.h" + +using namespace std; + +// Hiero and Joshua use log_10(e) as the value, so I do to +WordPenalty::WordPenalty(const string& param) : + fid_(FD::Convert("WordPenalty")), + value_(-1.0 / log(10)) { + if (!param.empty()) { + cerr << "Warning WordPenalty ignoring parameter: " << param << endl; + } +} + +void WordPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_states, + SparseVector* features, + SparseVector* estimated_features, + void* state) const { + (void) smeta; + (void) ant_states; + (void) state; + (void) estimated_features; + features->set_value(fid_, edge.rule_->EWords() * value_); +} + + +SourceWordPenalty::SourceWordPenalty(const string& param) : + fid_(FD::Convert("SourceWordPenalty")), + value_(-1.0 / log(10)) { + if (!param.empty()) { + cerr << "Warning SourceWordPenalty ignoring parameter: " << param << endl; + } +} + +void SourceWordPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_states, + SparseVector* features, + SparseVector* estimated_features, + void* state) const { + (void) smeta; + (void) ant_states; + (void) state; + (void) estimated_features; + features->set_value(fid_, edge.rule_->FWords() * value_); +} + +ArityPenalty::ArityPenalty(const std::string& param) { + string fname = "Arity_"; + unsigned MAX=DEFAULT_MAX_ARITY; + using namespace boost; + if (!param.empty()) + MAX=lexical_cast(param); + for (unsigned i = 0; i <= MAX; ++i) { + WordID fid=FD::Convert(fname+lexical_cast(i)); + fids_.push_back(fid); + } + // pretty up features vector in case FD was frozen. 
doesn't change anything + while (!fids_.empty() && fids_.back()==0) fids_.pop_back(); +} + +void ArityPenalty::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_states, + SparseVector* features, + SparseVector* estimated_features, + void* state) const { + (void) smeta; + (void) ant_states; + (void) state; + (void) estimated_features; + unsigned a=edge.Arity(); + if (a < fids_.size()) features->set_value(fids_[a], 1.0); +} + diff --git a/decoder/ff_basic.h b/decoder/ff_basic.h new file mode 100644 index 000000000..c63daf0fb --- /dev/null +++ b/decoder/ff_basic.h @@ -0,0 +1,67 @@ +#ifndef FF_BASIC_H_ +#define FF_BASIC_H_ + +#include "ff.h" + +// word penalty feature, for each word on the E side of a rule, +// add value_ +class WordPenalty : public FeatureFunction { + public: + WordPenalty(const std::string& param); + static std::string usage(bool p,bool d) { + return usage_helper("WordPenalty","","number of target words (local feature)",p,d); + } + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const HG::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const; + private: + const int fid_; + const double value_; +}; + +class SourceWordPenalty : public FeatureFunction { + public: + SourceWordPenalty(const std::string& param); + static std::string usage(bool p,bool d) { + return usage_helper("SourceWordPenalty","","number of source words (local feature, and meaningless except when input has non-constant number of source words, e.g. segmentation/morphology/speech recognition lattice)",p,d); + } + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const HG::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const; + private: + const int fid_; + const double value_; +}; + +#define DEFAULT_MAX_ARITY 50 +#define DEFAULT_MAX_ARITY_STRINGIZE(x) #x +#define DEFAULT_MAX_ARITY_STRINGIZE_EVAL(x) DEFAULT_MAX_ARITY_STRINGIZE(x) +#define DEFAULT_MAX_ARITY_STR DEFAULT_MAX_ARITY_STRINGIZE_EVAL(DEFAULT_MAX_ARITY) + +class ArityPenalty : public FeatureFunction { + public: + ArityPenalty(const std::string& param); + static std::string usage(bool p,bool d) { + return usage_helper("ArityPenalty","[MaxArity(default " DEFAULT_MAX_ARITY_STR ")]","Indicator feature Arity_N=1 for rule of arity N (local feature). 0<=N<=MaxArity(default " DEFAULT_MAX_ARITY_STR ")",p,d); + } + + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const HG::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const; + private: + std::vector fids_; +}; + +#endif diff --git a/decoder/ff_bleu.cc b/decoder/ff_bleu.cc new file mode 100644 index 000000000..a842bba80 --- /dev/null +++ b/decoder/ff_bleu.cc @@ -0,0 +1,289 @@ +namespace { +char const* bleu_usage_name="BLEUModel"; +char const* bleu_usage_short="[-o 3|4]"; +char const* bleu_usage_verbose="Uses feature id 0! Make sure there are no other features whose weights aren't specified or there may be conflicts. Computes oracle with weighted combination of BLEU and model score (from previous model set, using weights on edges?). 
Performs ngram context expansion; expect reference translation info in sentence metadata; if document scorer is IBM_BLEU_3, then use order 3; otherwise use order 4."; +} + + +#include +#include +#include "fast_lexical_cast.hpp" + +#include + +#include "ff_bleu.h" +#include "tdict.h" +#include "hg.h" +#include "stringlib.h" +#include "sentence_metadata.h" +#include "scorer.h" + +using namespace std; + +class BLEUModelImpl { + public: + explicit BLEUModelImpl(int order) : + buffer_(), order_(order), state_size_(OrderToStateSize(order) - 1), + floor_(-100.0), + kSTART(TD::Convert("")), + kSTOP(TD::Convert("")), + kUNKNOWN(TD::Convert("")), + kNONE(-1), + kSTAR(TD::Convert("<{STAR}>")) {} + + virtual ~BLEUModelImpl() { + } + + inline int StateSize(const void* state) const { + return *(static_cast(state) + state_size_); + } + + inline void SetStateSize(int size, void* state) const { + *(static_cast(state) + state_size_) = size; + } + + void GetRefToNgram() + {} + + string DebugStateToString(const void* state) const { + int len = StateSize(state); + const int* astate = reinterpret_cast(state); + string res = "["; + for (int i = 0; i < len; ++i) { + res += " "; + res += TD::Convert(astate[i]); + } + res += " ]"; + return res; + } + + inline double ProbNoRemnant(int i, int len) { + int edge = len; + bool flag = true; + double sum = 0.0; + while (i >= 0) { + if (buffer_[i] == kSTAR) { + edge = i; + flag = false; + } else if (buffer_[i] <= 0) { + edge = i; + flag = true; + } else { + if ((edge-i >= order_) || (flag && !(i == (len-1) && buffer_[i] == kSTART))) + { //sum += LookupProbForBufferContents(i); + //cerr << "FT"; + CalcPhrase(buffer_[i], &buffer_[i+1]); + } + } + --i; + } + return sum; + } + + double FinalTraversalCost(const void* state) { + int slen = StateSize(state); + int len = slen + 2; + // cerr << "residual len: " << len << endl; + buffer_.resize(len + 1); + buffer_[len] = kNONE; + buffer_[len-1] = kSTART; + const int* astate = reinterpret_cast(state); + int i = len - 2; + for (int j = 0; j < slen; ++j,--i) + buffer_[i] = astate[j]; + buffer_[i] = kSTOP; + assert(i == 0); + return ProbNoRemnant(len - 1, len); + } + + vector CalcPhrase(int word, int* context) { + int i = order_; + vector vs; + int c = 1; + vs.push_back(word); + // while (i > 1 && *context > 0) { + while (*context > 0) { + --i; + vs.push_back(*context); + ++context; + ++c; + } + if(false){ cerr << "VS1( "; + vector::reverse_iterator rit; + for ( rit=vs.rbegin() ; rit != vs.rend(); ++rit ) + cerr << " " << TD::Convert(*rit); + cerr << ")\n";} + + return vs; + } + + + double LookupWords(const TRule& rule, const vector& ant_states, void* vstate, const SentenceMetadata& smeta) { + + int len = rule.ELength() - rule.Arity(); + + for (int i = 0; i < ant_states.size(); ++i) + len += StateSize(ant_states[i]); + buffer_.resize(len + 1); + buffer_[len] = kNONE; + int i = len - 1; + const vector& e = rule.e(); + + /*cerr << "RULE::" << rule.ELength() << " "; + for (vector::const_iterator i = e.begin(); i != e.end(); ++i) + { + const WordID& c = *i; + if(c > 0) cerr << TD::Convert(c) << "--"; + else cerr <<"N--"; + } + cerr << endl; + */ + + for (int j = 0; j < e.size(); ++j) { + if (e[j] < 1) { + const int* astate = reinterpret_cast(ant_states[-e[j]]); + int slen = StateSize(astate); + for (int k = 0; k < slen; ++k) + buffer_[i--] = astate[k]; + } else { + buffer_[i--] = e[j]; + } + } + + double approx_bleu = 0.0; + int* remnant = reinterpret_cast(vstate); + int j = 0; + i = len - 1; + int edge = len; + + + vector vs; + while (i >= 
0) { + vs = CalcPhrase(buffer_[i],&buffer_[i+1]); + if (buffer_[i] == kSTAR) { + edge = i; + } else if (edge-i >= order_) { + + vs = CalcPhrase(buffer_[i],&buffer_[i+1]); + + } else if (edge == len && remnant) { + remnant[j++] = buffer_[i]; + } + --i; + } + + //calculate Bvector here + /* cerr << "VS1( "; + vector::reverse_iterator rit; + for ( rit=vs.rbegin() ; rit != vs.rend(); ++rit ) + cerr << " " << TD::Convert(*rit); + cerr << ")\n"; + */ + + ScoreP node_score_p = smeta.GetDocScorer()[smeta.GetSentenceID()]->ScoreCCandidate(vs); + Score *node_score=node_score_p.get(); + string details; + node_score->ScoreDetails(&details); + const Score *base_score= &smeta.GetScore(); + //cerr << "SWBASE : " << base_score->ComputeScore() << details << " "; + + int src_length = smeta.GetSourceLength(); + node_score->PlusPartialEquals(*base_score, rule.EWords(), rule.FWords(), src_length ); + float oracledoc_factor = (src_length + smeta.GetDocLen())/ src_length; + + //how it seems to be done in code + //TODO: might need to reverse the -1/+1 of the oracle/neg examples + //TO VLADIMIR: the polarity would be reversed if you switched error (1-BLEU) for BLEU. + approx_bleu = ( rule.FWords() * oracledoc_factor ) * node_score->ComputeScore(); + //how I thought it was done from the paper + //approx_bleu = ( rule.FWords()+ smeta.GetDocLen() ) * node_score->ComputeScore(); + + if (!remnant){ return approx_bleu;} + + if (edge != len || len >= order_) { + remnant[j++] = kSTAR; + if (order_-1 < edge) edge = order_-1; + for (int i = edge-1; i >= 0; --i) + remnant[j++] = buffer_[i]; + } + + SetStateSize(j, vstate); + //cerr << "Return APPROX_BLEU: " << approx_bleu << " "<< DebugStateToString(vstate) << endl; + return approx_bleu; + } + + static int OrderToStateSize(int order) { + return ((order-1) * 2 + 1) * sizeof(WordID) + 1; + } + + protected: + vector buffer_; + const int order_; + const int state_size_; + const double floor_; + + public: + const WordID kSTART; + const WordID kSTOP; + const WordID kUNKNOWN; + const WordID kNONE; + const WordID kSTAR; +}; + +string BLEUModel::usage(bool param,bool verbose) { + return usage_helper(bleu_usage_name,bleu_usage_short,bleu_usage_verbose,param,verbose); +} + +BLEUModel::BLEUModel(const string& param) : + fid_(0) { //The partial BLEU score is kept in feature id=0 + vector argv; + int argc = SplitOnWhitespace(param, &argv); + int order = 3; + + //loop over argv and load all references into vector of NgramMaps + if (argc >= 1) { + if (argv[0] != "-o" || argc<2) { + cerr<(argv[1]); + } + + SetStateSize(BLEUModelImpl::OrderToStateSize(order)); + pimpl_ = new BLEUModelImpl(order); +} + +BLEUModel::~BLEUModel() { + delete pimpl_; +} + +string BLEUModel::DebugStateToString(const void* state) const{ + return pimpl_->DebugStateToString(state); +} + +void BLEUModel::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const vector& ant_states, + SparseVector* features, + SparseVector* /* estimated_features */, + void* state) const { + + (void) smeta; + /*cerr << "In BM calling set " << endl; + const Score *s= &smeta.GetScore(); + const int dl = smeta.GetDocLen(); + cerr << "SCO " << s->ComputeScore() << endl; + const DocScorer *ds = &smeta.GetDocScorer(); + */ + +// cerr<< "ff_bleu loading sentence " << smeta.GetSentenceID() << endl; + //} + features->set_value(fid_, pimpl_->LookupWords(*edge.rule_, ant_states, state, smeta)); + //cerr << "FID" << fid_ << " " << DebugStateToString(state) << endl; +} + +void BLEUModel::FinalTraversalFeatures(const 
void* ant_state, + SparseVector* features) const { + + features->set_value(fid_, pimpl_->FinalTraversalCost(ant_state)); +} diff --git a/decoder/ff_bleu.h b/decoder/ff_bleu.h new file mode 100644 index 000000000..8ca2c0958 --- /dev/null +++ b/decoder/ff_bleu.h @@ -0,0 +1,32 @@ +#ifndef BLEU_FF_H_ +#define BLEU_FF_H_ + +#include +#include + +#include "hg.h" +#include "ff.h" + +class BLEUModelImpl; + +class BLEUModel : public FeatureFunction { + public: + // param = "filename.lm [-o n]" + BLEUModel(const std::string& param); + ~BLEUModel(); + virtual void FinalTraversalFeatures(const void* context, + SparseVector* features) const; + std::string DebugStateToString(const void* state) const; + static std::string usage(bool param,bool verbose); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const HG::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* out_context) const; + private: + const int fid_; + mutable BLEUModelImpl* pimpl_; +}; +#endif diff --git a/decoder/ff_charset.cc b/decoder/ff_charset.cc new file mode 100644 index 000000000..6429088b6 --- /dev/null +++ b/decoder/ff_charset.cc @@ -0,0 +1,44 @@ +#include "ff_charset.h" + +#include "tdict.h" +#include "hg.h" +#include "fdict.h" +#include "stringlib.h" + +using namespace std; + +NonLatinCount::NonLatinCount(const string& param) : FeatureFunction(), fid_(FD::Convert("NonLatinCount")) {} + +bool ContainsNonLatin(const string& word) { + unsigned cur = 0; + while(cur < word.size()) { + const int size = UTF8Len(word[cur]); + if (size > 1) return true; + cur += size; + } + return false; +} + +void NonLatinCount::TraversalFeaturesImpl(const SentenceMetadata& smeta, + const Hypergraph::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const { + const vector& e = edge.rule_->e(); + int count = 0; + for (int i = 0; i < e.size(); ++i) { + if (e[i] > 0) { + map::iterator it = is_non_latin_.find(e[i]); + if (it == is_non_latin_.end()) { + if ((is_non_latin_[e[i]] = ContainsNonLatin(TD::Convert(e[i])))) + ++count; + } else { + if (it->second) + ++count; + } + } + } + if (count) features->set_value(fid_, count); +} + diff --git a/decoder/ff_charset.h b/decoder/ff_charset.h new file mode 100644 index 000000000..e22ece2be --- /dev/null +++ b/decoder/ff_charset.h @@ -0,0 +1,26 @@ +#ifndef FFCHARSET_H_ +#define FFCHARSET_H_ + +#include +#include +#include "ff.h" +#include "hg.h" + +class SentenceMetadata; + +class NonLatinCount : public FeatureFunction { + public: + NonLatinCount(const std::string& param); + protected: + virtual void TraversalFeaturesImpl(const SentenceMetadata& smeta, + const HG::Edge& edge, + const std::vector& ant_contexts, + SparseVector* features, + SparseVector* estimated_features, + void* context) const; + private: + mutable std::map is_non_latin_; + const int fid_; +}; + +#endif diff --git a/decoder/ff_conll.cc b/decoder/ff_conll.cc new file mode 100644 index 000000000..8ded44b79 --- /dev/null +++ b/decoder/ff_conll.cc @@ -0,0 +1,250 @@ +#include "ff_conll.h" + +#include +#include +#include +#include +#include + +#include "hg.h" +#include "filelib.h" +#include "stringlib.h" +#include "sentence_metadata.h" +#include "lattice.h" +#include "fdict.h" +#include "verbose.h" +#include "tdict.h" + +CoNLLFeatures::CoNLLFeatures(const string& param) { + // cerr << "initializing CoNLLFeatures with parameters: " << param; + kSOS = TD::Convert(""); + kEOS = 
TD::Convert(""); + macro_regex = sregex::compile("%([xy])\\[(-[1-9][0-9]*|0|[1-9][1-9]*)]"); + ParseArgs(param); +} + +string CoNLLFeatures::Escape(const string& x) const { + string y = x; + for (int i = 0; i < y.size(); ++i) { + if (y[i] == '=') y[i]='_'; + if (y[i] == ';') y[i]='_'; + } + return y; +} + +// replace %x[relative_location] or %y[relative_location] with actual_token +// within feature_instance +void CoNLLFeatures::ReplaceMacroWithString( + string& feature_instance, bool token_vs_label, int relative_location, + const string& actual_token) const { + + stringstream macro; + if (token_vs_label) { + macro << "%x["; + } else { + macro << "%y["; + } + macro << relative_location << "]"; + int macro_index = feature_instance.find(macro.str()); + if (macro_index == string::npos) { + cerr << "Can't find macro " << macro.str() << " in feature template " + << feature_instance; + abort(); + } + feature_instance.replace(macro_index, macro.str().size(), actual_token); +} + +void CoNLLFeatures::ReplaceTokenMacroWithString( + string& feature_instance, int relative_location, + const string& actual_token) const { + + ReplaceMacroWithString(feature_instance, true, relative_location, + actual_token); +} + +void CoNLLFeatures::ReplaceLabelMacroWithString( + string& feature_instance, int relative_location, + const string& actual_token) const { + + ReplaceMacroWithString(feature_instance, false, relative_location, + actual_token); +} + +void CoNLLFeatures::Error(const string& error_message) const { + cerr << "Error: " << error_message << "\n\n" + + << "CoNLLFeatures Usage: \n" + << " feature_function=CoNLLFeatures -t
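To make the %x/%y macro convention above concrete, here is a small, self-contained sketch of the substitution that ReplaceMacroWithString performs. The ExpandTokenMacro name is hypothetical; the real code also distinguishes %x token macros from %y label macros and aborts when a template does not contain the requested macro.

    #include <iostream>
    #include <sstream>
    #include <string>

    // Replace one "%x[k]" macro in a feature template with an actual token.
    std::string ExpandTokenMacro(std::string templ, int k, const std::string& tok) {
      std::ostringstream macro;
      macro << "%x[" << k << "]";
      const std::string m = macro.str();
      const std::string::size_type pos = templ.find(m);
      if (pos == std::string::npos) return templ;  // real code reports an error here
      templ.replace(pos, m.size(), tok);
      return templ;
    }

    int main() {
      // Expand the template "U:%x[-1]_%x[0]" at a word preceded by "the".
      std::string t = "U:%x[-1]_%x[0]";
      t = ExpandTokenMacro(t, -1, "the");
      t = ExpandTokenMacro(t, 0, "house");
      std::cout << t << std::endl;  // prints "U:the_house"
    }
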