From bc2331c783ea1722f718a5f2f7f22b1fd211d885 Mon Sep 17 00:00:00 2001 From: Dominik Thalhammer Date: Mon, 17 May 2021 21:30:00 +0200 Subject: [PATCH] Refactoring (#5) Initial code cleanups --- .editorconfig | 17 + CMakeLists.txt | 39 +- apps/cut.cpp | 6 +- apps/enroll.cpp | 10 - lib/CMakeLists.txt | 41 +- lib/audio-lib.cpp | 2 +- lib/dtw-lib.cpp | 85 ++-- lib/dtw-lib.h | 2 - lib/eavesdrop-stream.h | 1 - lib/feat-lib.cpp | 63 ++- lib/feat-lib.h | 18 +- lib/fft-stream.h | 2 - lib/framer-stream.h | 2 - lib/frontend-stream.h | 2 - lib/gain-control-stream.h | 2 - lib/intercept-stream.h | 1 - lib/matrix-wrapper.cpp | 29 +- lib/matrix-wrapper.h | 20 +- lib/mfcc-stream.cpp | 13 +- lib/mfcc-stream.h | 2 - lib/nnet-component.cpp | 14 +- lib/nnet-component.h | 14 +- lib/nnet-lib.h | 1 - lib/nnet-stream.h | 2 - lib/pipeline-detect.cpp | 10 +- lib/pipeline-detect.h | 9 - lib/pipeline-personal-enroll.h | 1 - lib/pipeline-template-cut.h | 2 - lib/pipeline-vad.cpp | 3 +- lib/pipeline-vad.h | 1 - lib/raw-energy-vad-stream.h | 2 - lib/raw-nnet-vad-stream.cpp | 1 - lib/raw-nnet-vad-stream.h | 2 - lib/snowboy-debug.cpp | 14 +- lib/snowboy-debug.h | 17 +- lib/snowboy-detect-c.cpp | 673 ++++++++++++++++++++++++++++++++ lib/snowboy-detect-c.h | 69 ++++ lib/snowboy-detect.cpp | 16 +- lib/snowboy-detect.h | 13 +- lib/snowboy-options.cpp | 3 - lib/snowboy-options.h | 16 +- lib/template-container.h | 1 - lib/template-detect-stream.h | 2 - lib/template-enroll-stream.h | 2 - lib/universal-detect-stream.cpp | 497 +++++++++++------------ lib/universal-detect-stream.h | 114 +++--- lib/vad-lib.h | 2 - lib/vad-state-stream.h | 2 - lib/vector-wrapper.cpp | 117 +++--- lib/vector-wrapper.h | 20 +- lib/{types.h => wave-header.h} | 0 test/EnrollTest.cpp | 6 - test/helper.cpp | 191 ++++++++- test/helper.h | 46 ++- 54 files changed, 1576 insertions(+), 664 deletions(-) create mode 100644 .editorconfig create mode 100644 lib/snowboy-detect-c.cpp create mode 100644 lib/snowboy-detect-c.h rename lib/{types.h => wave-header.h} (100%) diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..33af44c --- /dev/null +++ b/.editorconfig @@ -0,0 +1,17 @@ +root = true + +[!*.{h,cpp}] +indent_style = space +indent_size = 2 +end_of_line = lf +charset = utf-8 +trim_trailing_whitespace = true +insert_final_newline = true + +[*.{h,cpp}] +indent_style = tab +trim_trailing_whitespace = true +insert_final_newline = true + +[*.md] +trim_trailing_whitespace = false diff --git a/CMakeLists.txt b/CMakeLists.txt index f5d8945..9f51667 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,20 +1,33 @@ -cmake_minimum_required(VERSION 3.10) -project(snowboy) +cmake_minimum_required(VERSION 3.12) +project(snowman VERSION 1.0.0 DESCRIPTION "Snowman hotword detection library") list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_LIST_DIR}/cmake) -#option(SNOWBOY_CXX11_COMPAT "Build library with C++11 strings disable to be binary compatible with the original release." OFF) -set(SNOWBOY_CXX11_COMPAT ON) +option(SNOWMAN_CXX11_COMPAT "Build library with C++11 strings disabled to be binary compatible with the original release." ON) +option(SNOWMAN_BUILD_APPS "Build helper applications like enroll or cut" ON) +option(SNOWMAN_BUILD_TESTS "Build unit tests (requires gtest and openssl)" ON) +option(SNOWMAN_BUILD_SHARED "Build library as a shared library instead of a static library" OFF) +# According to steam ~99.17% of users have at least ssse3 +# and 100% have SSE, SSE2 and SSE3, so this is on by default +option(SNOWMAN_BUILD_WITH_SSE3 "Enable sse3 optimizations" ON) +# ~98.3%, so this is on by default as well +option(SNOWMAN_BUILD_WITH_SSE4 "Enable sse4 optimizations" ON) +# ~94.7% +option(SNOWMAN_BUILD_WITH_AVX "Enable avx optimizations" OFF) +# ~82% +option(SNOWMAN_BUILD_WITH_AVX2 "Enable avx2 optimizations" OFF) +option(SNOWMAN_BUILD_NATIVE "Build library for the current cpu. This makes sure it uses every instruction set available, but the resulting binary probably won't run on older hardware." OFF) -if(SNOWBOY_CXX11_COMPAT) -add_compile_options(-D_GLIBCXX_USE_CXX11_ABI=0) -endif() - -add_compile_options(-std=c++0x -Wall -Wno-sign-compare -Wno-unused-local-typedefs -Winit-self -rdynamic) -add_compile_options(-DHAVE_POSIX_MEMALIGN -I. -fno-omit-frame-pointer -fPIC -msse -msse2) add_compile_options("$<$:-fsanitize=address>") -add_link_options(-rdynamic) add_link_options("$<$:-fsanitize=address>") +if(SNOWMAN_CXX11_COMPAT) +add_compile_options(-D_GLIBCXX_USE_CXX11_ABI=0) +endif() + add_subdirectory(lib) -add_subdirectory(apps) -add_subdirectory(test) \ No newline at end of file +if(SNOWMAN_BUILD_APPS) + add_subdirectory(apps) +endif() +if(SNOWMAN_BUILD_TESTS) + add_subdirectory(test) +endif() \ No newline at end of file diff --git a/apps/cut.cpp b/apps/cut.cpp index 70c0a49..c303d6f 100644 --- a/apps/cut.cpp +++ b/apps/cut.cpp @@ -1,13 +1,9 @@ #include -#include #include #include #include -#include #include -#include -#include -#include +#include const static auto root = detect_project_root(); diff --git a/apps/enroll.cpp b/apps/enroll.cpp index aaa9746..24c0054 100644 --- a/apps/enroll.cpp +++ b/apps/enroll.cpp @@ -1,14 +1,7 @@ #include -#include -#include #include #include -#include -#include -#include #include -#include -#include const static auto root = detect_project_root(); @@ -20,9 +13,6 @@ int main(int argc, const char** argv) { bool cut_recordings; if (!parse_args(argc, argv, output, recordings, lang, cut_recordings)) return -1; - { - std::ofstream t{output, std::ios::binary | std::ios::trunc}; - } snowboy::SnowboyPersonalEnroll enroll{root + "resources/pmdl/" + lang + "/personal_enroll.res", output}; snowboy::SnowboyTemplateCut cut{root + "resources/pmdl/" + lang + "/personal_enroll.res"}; for (auto& e : recordings) { diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index fb7fd90..adf653c 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -1,5 +1,5 @@ -add_library(snowboy_reimpl +set(SNOWMAN_SRC ${CMAKE_CURRENT_SOURCE_DIR}/agc.cpp ${CMAKE_CURRENT_SOURCE_DIR}/audio-lib.cpp ${CMAKE_CURRENT_SOURCE_DIR}/dtw-lib.cpp @@ -28,6 +28,7 @@ add_library(snowboy_reimpl ${CMAKE_CURRENT_SOURCE_DIR}/raw-nnet-vad-stream.cpp ${CMAKE_CURRENT_SOURCE_DIR}/snowboy-debug.cpp ${CMAKE_CURRENT_SOURCE_DIR}/snowboy-detect.cpp + ${CMAKE_CURRENT_SOURCE_DIR}/snowboy-detect-c.cpp ${CMAKE_CURRENT_SOURCE_DIR}/snowboy-io.cpp ${CMAKE_CURRENT_SOURCE_DIR}/snowboy-math.cpp ${CMAKE_CURRENT_SOURCE_DIR}/snowboy-options.cpp @@ -41,11 +42,35 @@ add_library(snowboy_reimpl ${CMAKE_CURRENT_SOURCE_DIR}/vad-state-stream.cpp ${CMAKE_CURRENT_SOURCE_DIR}/vector-wrapper.cpp ) -target_include_directories(snowboy_reimpl PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) -target_link_directories(snowboy_reimpl PRIVATE /usr/lib/atlas-base) -target_link_libraries(snowboy_reimpl dl m rt pthread f77blas cblas lapack_atlas atlas) +if(SNOWMAN_BUILD_SHARED) + add_library(snowman SHARED ${SNOWMAN_SRC}) + set_target_properties(snowman PROPERTIES VERSION ${PROJECT_VERSION}) + set_target_properties(mylib PROPERTIES SOVERSION ${CMAKE_PROJECT_VERSION_MAJOR}) +else() + add_library(snowman STATIC ${SNOWMAN_SRC}) +endif() +target_compile_features(snowman PRIVATE cxx_std_11) +target_compile_options(snowman PRIVATE -Wall -Wno-sign-compare -Winit-self -rdynamic) +target_compile_options(snowman PRIVATE -DHAVE_POSIX_MEMALIGN -fno-omit-frame-pointer) +target_include_directories(snowman PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) +target_link_directories(snowman PRIVATE /usr/lib/atlas-base) +target_link_libraries(snowman m pthread f77blas cblas lapack_atlas atlas) + +if(SNOWMAN_BUILD_WITH_SSE3) +target_compile_options(snowman PRIVATE -msse -msse2 -msse3 -mssse3) +endif() +if(SNOWMAN_BUILD_WITH_SSE4) +target_compile_options(snowman PRIVATE -msse4.2) +endif() +if(SNOWMAN_BUILD_WITH_AVX) +target_compile_options(snowman PRIVATE -mavx) +endif() +if(SNOWMAN_BUILD_WITH_AVX2) +target_compile_options(snowman PRIVATE -mavx2) +endif() +if(SNOWMAN_BUILD_NATIVE) +target_compile_options(snowman PRIVATE -march=native -mtune=native) +endif() + +add_library(snowboy ALIAS snowman) -file(TOUCH ${CMAKE_CURRENT_BINARY_DIR}/dummy.cpp) -add_library(snowboy ${CMAKE_CURRENT_BINARY_DIR}/dummy.cpp) -target_link_libraries(snowboy -Wl,--start-group snowboy_reimpl -Wl,--end-group) -#target_link_libraries(snowboy -Wl,--start-group snowboy_reimpl ${CMAKE_CURRENT_SOURCE_DIR}/../resources/nnet-lib.o -Wl,--end-group) diff --git a/lib/audio-lib.cpp b/lib/audio-lib.cpp index b58310b..a2d85c2 100644 --- a/lib/audio-lib.cpp +++ b/lib/audio-lib.cpp @@ -1,7 +1,7 @@ #include #include #include -#include +#include namespace snowboy { float GetMaxWaveAmplitude(const WaveHeader& hdr) { diff --git a/lib/dtw-lib.cpp b/lib/dtw-lib.cpp index 4e76657..094f785 100644 --- a/lib/dtw-lib.cpp +++ b/lib/dtw-lib.cpp @@ -109,7 +109,7 @@ namespace snowboy { auto local_22c = std::numeric_limits::max(); for (auto row = 0; row < param_2.m_rows; row++) { /* try { // try from 00101d18 to 00101d74 has its CatchHandler @ 00102283 */ - int local_1e8, local_1e4, local_1e0, local_1dc; + int local_1e8 = 0, local_1e4 = 0, local_1e0 = 0, local_1dc = 0; snowboy::SlidingDtw::ComputeBandBoundary(row, &local_1e8, &local_1e4); if (0 < row) { snowboy::SlidingDtw::ComputeBandBoundary(row - 1, &local_1e0, &local_1dc); @@ -200,22 +200,16 @@ namespace snowboy { Matrix local_1d8; local_1d8.Resize(param_2.m_rows, param_3.m_rows); for (auto row = 0; row != local_1d8.m_rows; row++) { - const int iVar7 = (int(row) - 1) * (int)local_1d8.m_stride; - const int iVar14 = row * local_1f8.m_stride; - const int local_284 = row * local_1f8.m_stride; - const int iVar11 = row * local_1d8.m_stride; - const int iVar15 = row * local_1d8.m_stride; if (0 < local_1d8.m_cols) { - auto pfVar9 = local_1d8.m_data + (long)iVar7 + -1; - auto pfVar8 = local_1d8.m_data + (long)iVar11 + -1; - if (row == 0) { - auto lVar12 = 0; - do { + auto pfVar9 = local_1d8.m_data + (row - 1) * local_1d8.m_stride - 1; + auto pfVar8 = local_1d8.m_data + (row * local_1d8.m_stride) - 1; + for (auto lVar12 = 0; lVar12 < local_1d8.m_cols; lVar12++) { + if (row == 0) { while (((int)lVar12 == 0 || (row == 0))) { - pfVar9 = pfVar9 + 1; local_1d8.m_data[lVar12] = local_1f8.m_data[lVar12]; - lVar12 += 1; + pfVar9 = pfVar9 + 1; pfVar8 = pfVar8 + 1; + lVar12++; if (local_1d8.m_cols <= (int)lVar12) goto LAB_00186e9d; } auto fVar16 = pfVar9[1]; @@ -225,18 +219,12 @@ namespace snowboy { if (*pfVar9 <= fVar16) { fVar16 = *pfVar9; } - auto lVar6 = iVar14 + lVar12; - lVar12 += 1; - pfVar8[1] = fVar16 + local_1f8.m_data[lVar6]; - pfVar9 = pfVar9 + 1; - pfVar8 = pfVar8 + 1; - } while ((int)lVar12 < local_1d8.m_cols); - } else { - auto lVar12 = 0; - do { + pfVar8[1] = fVar16 + local_1f8.m_data[row * local_1f8.m_stride + lVar12]; + + } else { if (((int)lVar12 == 0) || (row == 0)) { if ((int)lVar12 == 0) { - local_1d8.m_data[iVar15] = local_1f8.m_data[local_284] + local_1d8.m_data[iVar15 - local_1d8.m_stride]; + local_1d8.m_data[row * local_1d8.m_stride] = local_1f8.m_data[row * local_1f8.m_stride] + local_1d8.m_data[(row - 1) * local_1d8.m_stride]; } } else { auto fVar16 = pfVar9[1]; @@ -246,40 +234,37 @@ namespace snowboy { if (*pfVar9 <= fVar16) { fVar16 = *pfVar9; } - pfVar8[1] = fVar16 + local_1f8.m_data[iVar14 + lVar12]; + pfVar8[1] = fVar16 + local_1f8.m_data[row * local_1f8.m_stride + lVar12]; } - lVar12 += 1; - pfVar9 = pfVar9 + 1; - pfVar8 = pfVar8 + 1; - } while ((int)lVar12 < local_1d8.m_cols); + } + pfVar9 = pfVar9 + 1; + pfVar8 = pfVar8 + 1; } } LAB_00186e9d: []() {}(); // TODO: This is just here cause for some reason a label directly before the closing bracket does not work } - auto local_228 = -1; - int iVar11 = local_1d8.m_rows - 1; - SubVector{local_1d8, iVar11}.Min(&local_228); - auto fVar16 = local_1d8.m_data[local_1d8.m_stride * iVar11 + local_228]; + auto min_index = -1; + auto min_value = SubVector{local_1d8, local_1d8.m_rows - 1}.Min(&min_index); if (param_4 != nullptr) { - while (iVar11 != 0) { + for (int iVar11 = local_1d8.m_rows - 1; iVar11 != 0;) { // TODO: This is wrong // If I look at the code it should only be - // param_4->at(iVar11).push_back(local_228); + // param_4->at(iVar11).push_back(min_index); // But that produces different results from what it should if (param_4->at(iVar11).empty()) - param_4->at(iVar11).push_back(local_228); + param_4->at(iVar11).push_back(min_index); else - param_4->at(iVar11).at(0) = local_228; - if (0 >= local_228) { + param_4->at(iVar11).at(0) = min_index; + if (0 >= min_index) { iVar11--; continue; } - auto fVar18 = local_1d8.m_data[local_1d8.m_stride * iVar11 + local_228] - local_1f8.m_data[local_1f8.m_stride * iVar11 + local_228]; - float pfVar8[3] = {fVar18, fVar18, fVar18}; - pfVar8[0] = std::abs(fVar18 - local_1d8.m_data[(iVar11 + -1) * local_1d8.m_stride + (local_228 - 1)]); - pfVar8[1] = std::abs(fVar18 - local_1d8.m_data[(local_228 - 1) + iVar11 * local_1d8.m_stride]); - pfVar8[2] = std::abs(fVar18 - local_1d8.m_data[(iVar11 + -1) * local_1d8.m_stride + local_228]); + auto fVar18 = local_1d8.m_data[local_1d8.m_stride * iVar11 + min_index] - local_1f8.m_data[local_1f8.m_stride * iVar11 + min_index]; + float pfVar8[3]; + pfVar8[0] = std::abs(fVar18 - local_1d8.m_data[(iVar11 + -1) * local_1d8.m_stride + (min_index - 1)]); + pfVar8[1] = std::abs(fVar18 - local_1d8.m_data[(min_index - 1) + iVar11 * local_1d8.m_stride]); + pfVar8[2] = std::abs(fVar18 - local_1d8.m_data[(iVar11 + -1) * local_1d8.m_stride + min_index]); auto pfVar9 = pfVar8 + 1; if (pfVar8[0] <= pfVar8[1]) { pfVar9 = pfVar8; @@ -290,26 +275,26 @@ namespace snowboy { auto iVar10 = (int)((long)((long)pfVar9 - (long)pfVar8) >> 2); if (iVar10 != 0) { if (iVar10 == 1) { - local_228 -= 1; + min_index -= 1; } else { if (iVar10 == 2) { - iVar11 = iVar11 + -1; + iVar11--; } } } else { - local_228 -= 1; - iVar11 = iVar11 + -1; + min_index -= 1; + iVar11--; } } // TODO: This is wrong // If I look at the code it should only be - // param_4->at(0).push_back(local_228); + // param_4->at(0).push_back(min_index); // But that produces different results from what it should if (param_4->at(0).empty()) - param_4->at(0).push_back(local_228); + param_4->at(0).push_back(min_index); else - param_4->at(0).at(0) = local_228; + param_4->at(0).at(0) = min_index; } - return fVar16 / param_2.m_rows; + return min_value / param_2.m_rows; } } // namespace snowboy \ No newline at end of file diff --git a/lib/dtw-lib.h b/lib/dtw-lib.h index 9cbe827..d491443 100644 --- a/lib/dtw-lib.h +++ b/lib/dtw-lib.h @@ -16,7 +16,6 @@ namespace snowboy { // TODO: This could be replaced with enum DistanceType std::string distance_metric; }; - static_assert(sizeof(SlidingDtwOptions) == 0x10); struct SlidingDtw { SlidingDtwOptions m_options; std::deque> field_x18; @@ -38,7 +37,6 @@ namespace snowboy { void ComputeBandBoundary(int, int*, int*) const; virtual ~SlidingDtw(); }; - static_assert(sizeof(SlidingDtw) == 0x78); float DtwAlign(DistanceType, const MatrixBase&, const MatrixBase&, std::vector>*); } // namespace snowboy \ No newline at end of file diff --git a/lib/eavesdrop-stream.h b/lib/eavesdrop-stream.h index a343f81..d7fbcdd 100644 --- a/lib/eavesdrop-stream.h +++ b/lib/eavesdrop-stream.h @@ -14,5 +14,4 @@ namespace snowboy { virtual std::string Name() const override; virtual ~EavesdropStream(); }; - static_assert(sizeof(EavesdropStream) == 0x28); } // namespace snowboy \ No newline at end of file diff --git a/lib/feat-lib.cpp b/lib/feat-lib.cpp index 684175e..ac67247 100644 --- a/lib/feat-lib.cpp +++ b/lib/feat-lib.cpp @@ -22,13 +22,11 @@ namespace snowboy { InitMelFilterBank(); } - MelFilterBank::~MelFilterBank() {} - void MelFilterBank::InitMelFilterBank() { // TODO: This might contain bugs and generally needs a proper rewrite, but I dont know enough about audio processing to do it field_x28.resize(m_options.num_bins, 0); field_x40.resize(m_options.num_bins); - auto fVar13 = logf(m_options.low_frequency / 700.0f + 1.0f); + const auto fVar13 = logf(m_options.low_frequency / 700.0f + 1.0f); auto fVar14 = logf(m_options.high_frequency / 700.0f + 1.0f); fVar14 = (fVar14 * 1127.0 - fVar13 * 1127.0) / static_cast(m_options.num_bins + 1); auto fVar17 = static_cast(m_options.sample_rate) / static_cast(m_options.num_fft_points); @@ -97,11 +95,10 @@ namespace snowboy { } } - void MelFilterBank::ComputeMelFilterBankEnergy(const VectorBase& param_1, Vector* param_2) const { - if (m_options.num_bins != param_2->m_size) param_2->Resize(m_options.num_bins); - for (int b = 0; b < param_2->m_size; b++) { - auto f = field_x40[b].DotVec(param_1.Range(field_x28[b], field_x40[b].m_size)); - param_2->m_data[b] = f; + void MelFilterBank::ComputeMelFilterBankEnergy(const VectorBase& input, Vector& param_2) const { + if (m_options.num_bins != param_2.size()) param_2.Resize(m_options.num_bins); + for (int b = 0; b < param_2.size(); b++) { + param_2[b] = field_x40[b].DotVec(input.Range(field_x28[b], field_x40[b].size())); } } @@ -131,15 +128,13 @@ namespace snowboy { } } - void ComputePowerSpectrumReal(Vector* data) { - if (data->m_size == 0) return; - auto ptr = data->m_data; - float f = ptr[0] * ptr[0]; - for (int i = 0; i < data->m_size / 2 - 1; i++) { - ptr[i + 1] = ptr[i * 2 + 3] * ptr[i * 2 + 3] + ptr[i * 2 + 2] * ptr[i * 2 + 2]; + void ComputePowerSpectrumReal(Vector& data) { + if (data.empty()) return; + data[0] = data[0] * data[0]; + for (int i = 0; i < data.size() / 2 - 1; i++) { + data[i + 1] = data[i * 2 + 3] * data[i * 2 + 3] + data[i * 2 + 2] * data[i * 2 + 2]; } - ptr[0] = f; - data->Resize(data->m_size / 2, MatrixResizeType::kCopyData); + data.Resize(data.size() / 2, MatrixResizeType::kCopyData); } FftItf::~FftItf() {} @@ -175,12 +170,11 @@ namespace snowboy { auto lVar14 = (long)local_68 * 2 + 1 + iVar9; auto lVar15 = (long)(local_68 * 2); for (auto iVar11 = 0; iVar11 != iVar9 / 2; iVar11++) { - float local_40, local_3c; - snowboy::Fft::GetTwiddleFactor(iVar9, iVar11, &local_40, &local_3c); - if (inverse) local_3c *= -1; + auto twiddle = snowboy::Fft::GetTwiddleFactor(iVar9, iVar11); + if (inverse) twiddle.second *= -1; auto fVar16 = pfVar5[lVar15 + iVar9]; - auto fVar18 = fVar16 * local_40 - local_3c * pfVar5[lVar14]; - fVar16 = pfVar5[lVar14] * local_40 + fVar16 * local_3c; + auto fVar18 = fVar16 * twiddle.first - twiddle.second * pfVar5[lVar14]; + fVar16 = pfVar5[lVar14] * twiddle.first + fVar16 * twiddle.second; pfVar5[lVar15 + iVar9] = pfVar5[lVar15] - fVar18; pfVar5[lVar14] = pfVar5[lVar15 + 1] - fVar16; pfVar5[lVar15] += fVar18; @@ -257,17 +251,16 @@ namespace snowboy { auto lVar12 = 2; auto lVar9 = num_pts; for (auto iVar13 = 1; iVar13 <= iVar11 / 4; iVar13 += 1) { - float twiddle_a, twiddle_b; - snowboy::Fft::GetTwiddleFactor(num_pts, param_1 ? (static_cast(num_pts) * 0.5 - iVar13) : iVar13, &twiddle_a, &twiddle_b); + const auto twiddle = snowboy::Fft::GetTwiddleFactor(num_pts, param_1 ? (static_cast(num_pts) * 0.5 - iVar13) : iVar13); const auto fVar4 = ptr[lVar9 - 1]; const auto fVar5 = ptr[lVar9 - 2]; const auto fVar6 = ptr[lVar12]; const auto fVar7 = ptr[lVar12 + 1]; - ptr[lVar12] = (twiddle_a * fVar7 + (twiddle_b + 1.0) * fVar6 + (1.0 - twiddle_b) * fVar5 + fVar4 * twiddle_a) * 0.5; - ptr[lVar12 + 1] = ((twiddle_b + 1.0) * fVar7 + ((fVar5 * twiddle_a - (1.0 - twiddle_b) * fVar4) - twiddle_a * fVar6)) * 0.5; + ptr[lVar12] = (twiddle.first * fVar7 + (twiddle.second + 1.0) * fVar6 + (1.0 - twiddle.second) * fVar5 + fVar4 * twiddle.first) * 0.5; + ptr[lVar12 + 1] = ((twiddle.second + 1.0) * fVar7 + ((fVar5 * twiddle.first - (1.0 - twiddle.second) * fVar4) - twiddle.first * fVar6)) * 0.5; if (iVar13 * 2 != lVar9 - 2) { - ptr[lVar9 - 2] = ((((twiddle_b + 1.0) * fVar5 - fVar4 * twiddle_a) + (1.0 - twiddle_b) * fVar6) - twiddle_a * fVar7) * 0.5; - ptr[lVar9 - 1] = (((fVar5 * twiddle_a + fVar4 * (twiddle_b + 1.0)) - fVar6 * twiddle_a) - fVar7 * (1.0 - twiddle_b)) * 0.5; + ptr[lVar9 - 2] = ((((twiddle.second + 1.0) * fVar5 - fVar4 * twiddle.first) + (1.0 - twiddle.second) * fVar6) - twiddle.first * fVar7) * 0.5; + ptr[lVar9 - 1] = (((fVar5 * twiddle.first + fVar4 * (twiddle.second + 1.0)) - fVar6 * twiddle.first) - fVar7 * (1.0 - twiddle.second)) * 0.5; } lVar12 += 2; lVar9 -= 2; @@ -280,26 +273,16 @@ namespace snowboy { } unsigned int Fft::GetNumBits(unsigned int param_1) const { - // Note: This is the original function, but using __builtin_clz should generate way faster code - //unsigned int res = 0; - //while(param_1 > 1) { - // res += 1; - // param_1 >>= 1; - //} - //return res; - // TODO: Use C++20 bit operations once we build for real return param_1 == 0 ? 0 : (31 - __builtin_clz(param_1)); } - void Fft::GetTwiddleFactor(int param_1, int param_2, float* param_3, float* param_4) const { + std::pair Fft::GetTwiddleFactor(int param_1, int param_2) const { auto size = m_twiddle_factors.size(); auto idx = (size / param_1) * param_2 * 2; if (idx < size) { - *param_3 = m_twiddle_factors[idx]; - *param_4 = m_twiddle_factors[idx + 1]; + return {m_twiddle_factors[idx], m_twiddle_factors[idx + 1]}; } else { - *param_3 = m_twiddle_factors[size - idx] * -1; - *param_3 = m_twiddle_factors[(size - idx) + 1] * -1; + return {m_twiddle_factors[size - idx] * -1.0f, m_twiddle_factors[(size - idx) + 1] * -1.0f}; } } diff --git a/lib/feat-lib.h b/lib/feat-lib.h index 8ab2e47..28ac445 100644 --- a/lib/feat-lib.h +++ b/lib/feat-lib.h @@ -18,26 +18,26 @@ namespace snowboy { void Register(const std::string& prefix, OptionsItf* opts); }; - static_assert(sizeof(MelFilterBankOptions) == 0x20); - struct MelFilterBank { + class MelFilterBank { MelFilterBankOptions m_options; // Both hold num_bins entries std::vector field_x28; std::vector field_x40; - MelFilterBank(const MelFilterBankOptions& options); - virtual ~MelFilterBank(); // TODO: Does not need to be virtual (imho) but we keep it that way to make sure the layout matches void InitMelFilterBank(); float GetVtlnWarping(float) const; - void ComputeMelFilterBankEnergy(const VectorBase&, Vector*) const; void ValidateOptions() const; + + public: + MelFilterBank(const MelFilterBankOptions& options); + ~MelFilterBank() {} + void ComputeMelFilterBankEnergy(const VectorBase& input, Vector& output) const; }; - static_assert(sizeof(MelFilterBank) == 0x58); void ComputeDctMatrixTypeIII(Matrix* mat); void ComputeCepstralLifterCoeffs(float, Vector*); - void ComputePowerSpectrumReal(Vector*); + void ComputePowerSpectrumReal(Vector&); struct FftItf { virtual void DoFft(Vector*) const = 0; @@ -64,7 +64,7 @@ namespace snowboy { void ComputeBitReversalIndex(int, std::vector*) const; void DoProcessingForReal(bool, Vector*) const; unsigned int GetNumBits(unsigned int) const; - void GetTwiddleFactor(int, int, float*, float*) const; + std::pair GetTwiddleFactor(int, int) const; void Init(); unsigned int ReverseBit(unsigned int, unsigned int) const; void SetOptions(const FftOptions& opts); @@ -73,7 +73,6 @@ namespace snowboy { virtual void DoIfft(Vector*) const override; virtual ~Fft(); }; - static_assert(sizeof(Fft) == 0x48); struct SplitRadixFft : FftItf { FftOptions m_options; @@ -96,5 +95,4 @@ namespace snowboy { virtual void DoIfft(Vector*) const override; virtual ~SplitRadixFft(); }; - static_assert(sizeof(SplitRadixFft) == 0x48); } // namespace snowboy \ No newline at end of file diff --git a/lib/fft-stream.h b/lib/fft-stream.h index 7b740e1..31a475b 100644 --- a/lib/fft-stream.h +++ b/lib/fft-stream.h @@ -10,7 +10,6 @@ namespace snowboy { std::string method; void Register(const std::string& prefix, OptionsItf* options); }; - static_assert(sizeof(FftStreamOptions) == 0x10); struct FftStream : StreamItf { FftStreamOptions m_options; std::unique_ptr m_fft; @@ -24,6 +23,5 @@ namespace snowboy { virtual std::string Name() const override; virtual ~FftStream(); }; - static_assert(sizeof(FftStream) == 0x38); } // namespace snowboy \ No newline at end of file diff --git a/lib/framer-stream.h b/lib/framer-stream.h index 5f58cbe..6a5fda8 100644 --- a/lib/framer-stream.h +++ b/lib/framer-stream.h @@ -16,7 +16,6 @@ namespace snowboy { std::string window_type; void Register(const std::string&, OptionsItf*); }; - static_assert(sizeof(FramerStreamOptions) == 0x20); struct FramerStream : StreamItf { FramerStreamOptions m_options; int field_x38; @@ -36,5 +35,4 @@ namespace snowboy { virtual std::string Name() const override; virtual ~FramerStream(); }; - static_assert(sizeof(FramerStream) == 0x68); } // namespace snowboy \ No newline at end of file diff --git a/lib/frontend-stream.h b/lib/frontend-stream.h index cf6ed41..1ddd221 100644 --- a/lib/frontend-stream.h +++ b/lib/frontend-stream.h @@ -14,7 +14,6 @@ namespace snowboy { std::string agc_power; void Register(const std::string&, OptionsItf*); }; - static_assert(sizeof(FrontendStreamOptions) == 0x20); struct FrontendStream : StreamItf { std::string m_ns_power; std::string m_dr_power; @@ -33,5 +32,4 @@ namespace snowboy { virtual std::string Name() const override; virtual ~FrontendStream(); }; - static_assert(sizeof(FrontendStream) == 0x68); } // namespace snowboy \ No newline at end of file diff --git a/lib/gain-control-stream.h b/lib/gain-control-stream.h index c68722e..2f26119 100644 --- a/lib/gain-control-stream.h +++ b/lib/gain-control-stream.h @@ -7,7 +7,6 @@ namespace snowboy { float m_audioGain; void Register(const std::string&, OptionsItf*); }; - static_assert(sizeof(GainControlStreamOptions) == 4); struct GainControlStream : StreamItf { float m_audioGain; float m_maxAudioAmplitude; @@ -21,5 +20,4 @@ namespace snowboy { void SetAudioGain(float gain); void SetMaxAudioAmplitude(float amp); }; - static_assert(sizeof(GainControlStream) == 0x20); } // namespace snowboy \ No newline at end of file diff --git a/lib/intercept-stream.h b/lib/intercept-stream.h index 71d3855..3f7fc38 100644 --- a/lib/intercept-stream.h +++ b/lib/intercept-stream.h @@ -19,5 +19,4 @@ namespace snowboy { void ReadData(Matrix* mat, std::vector* info, SnowboySignal* signal); void SetData(const MatrixBase& mat, const std::vector& info, const SnowboySignal& signal); }; - static_assert(sizeof(InterceptStream) == 0x108); } // namespace snowboy \ No newline at end of file diff --git a/lib/matrix-wrapper.cpp b/lib/matrix-wrapper.cpp index 05d4c00..f88c152 100644 --- a/lib/matrix-wrapper.cpp +++ b/lib/matrix-wrapper.cpp @@ -250,15 +250,36 @@ namespace snowboy { static size_t allocs = 0; static size_t frees = 0; + template + constexpr inline T next_multiple_of(T val, T multi) noexcept { + return (val + multi - 1) & ~(multi - 1); + } + void Matrix::Resize(int rows, int cols, MatrixResizeType resize) { - // TODO: Smarter alloc similar to vector - if (m_rows == rows && m_cols == cols) { - if (resize == MatrixResizeType::kSetZero) Set(0.0f); + if (cols == 0 && rows == 0) { + m_rows = 0; + m_cols = 0; return; } + // TODO: Smarter alloc similar to vector + uint64_t mem_size = static_cast(m_rows) * static_cast(m_stride); + uint64_t new_size = static_cast(rows) * next_multiple_of(cols, 4); + if (new_size <= mem_size) { + if (resize == MatrixResizeType::kUndefined || resize == MatrixResizeType::kSetZero) { + m_rows = rows; + m_cols = cols; + m_stride = next_multiple_of(cols, 4); + if (resize == MatrixResizeType::kSetZero) Set(0.0f); + return; + } else if (cols <= m_stride) { + m_rows = rows; + m_cols = cols; + return; + } + } if (m_data == nullptr) { AllocateMatrixMemory(rows, cols); - Set(0.0f); + if (resize == MatrixResizeType::kSetZero) Set(0.0f); return; } if (resize == MatrixResizeType::kCopyData) { diff --git a/lib/matrix-wrapper.h b/lib/matrix-wrapper.h index f405809..9e1f0d5 100644 --- a/lib/matrix-wrapper.h +++ b/lib/matrix-wrapper.h @@ -14,6 +14,13 @@ namespace snowboy { uint32_t m_stride{0}; float* m_data{nullptr}; + size_t rows() const noexcept { return m_rows; } + size_t cols() const noexcept { return m_cols; } + size_t stride() const noexcept { return m_stride; } + float* data() const noexcept { return m_data; } + float& operator()(size_t row, size_t col) const noexcept { return m_data[row * m_stride + col]; } + bool empty() const noexcept { return rows() == 0 || cols() == 0; } + void AddMat(float alpha, const MatrixBase& A, MatrixTransposeType transA); void AddMatMat(float, const MatrixBase&, MatrixTransposeType, const MatrixBase&, MatrixTransposeType, float); void AddVecToRows(float, const VectorBase&); @@ -76,15 +83,7 @@ namespace snowboy { Matrix& operator=(const Matrix& other); Matrix& operator=(const MatrixBase& other); Matrix& operator=(Matrix&& other) { - ReleaseMatrixMemory(); - m_rows = other.m_rows; - m_cols = other.m_cols; - m_stride = other.m_stride; - m_data = other.m_data; - other.m_rows = 0; - other.m_data = nullptr; - other.m_stride = 0; - other.m_cols = 0; + Swap(&other); return *this; } @@ -100,9 +99,6 @@ namespace snowboy { struct SubMatrix : MatrixBase { SubMatrix(const MatrixBase& parent, int rowoffset, int rows, int coloffset, int cols); }; - static_assert(sizeof(MatrixBase) == 0x18); - static_assert(sizeof(Matrix) == 0x18); - static_assert(sizeof(SubMatrix) == 0x18); std::ostream& operator<<(std::ostream&, const MatrixBase&); } // namespace snowboy \ No newline at end of file diff --git a/lib/mfcc-stream.cpp b/lib/mfcc-stream.cpp index 85deea3..827e90d 100644 --- a/lib/mfcc-stream.cpp +++ b/lib/mfcc-stream.cpp @@ -17,17 +17,14 @@ namespace snowboy { m_options = options; field_x44 = -1; field_x48 = 0.0f; + // TODO: Can we optimize this matrix ? Matrix m; m.Resize(m_options.mel_filter.num_bins, m_options.mel_filter.num_bins); - Vector v; ComputeDctMatrixTypeIII(&m); - v.Resize(m_options.num_cepstral_coeffs); - ComputeCepstralLifterCoeffs(m_options.cepstral_lifter, &v); + m_cepstral_coeffs.Resize(m_options.num_cepstral_coeffs, MatrixResizeType::kUndefined); + ComputeCepstralLifterCoeffs(m_options.cepstral_lifter, &m_cepstral_coeffs); m_dct_matrix.Resize(m_options.num_cepstral_coeffs, m_options.mel_filter.num_bins); m_dct_matrix.CopyFromMat(m.RowRange(0, m_options.num_cepstral_coeffs), MatrixTransposeType::kNoTrans); - // TODO: These two could probably be a move - m_cepstral_coeffs.Resize(m_options.num_cepstral_coeffs); - m_cepstral_coeffs.CopyFromVec(v); } int MfccStream::Read(Matrix* mat, std::vector* info) { @@ -87,9 +84,9 @@ namespace snowboy { Vector v; v.Resize(param_1.m_size); v.CopyFromVec(param_1); - ComputePowerSpectrumReal(&v); + ComputePowerSpectrumReal(v); Vector vout; - m_melfilterbank->ComputeMelFilterBankEnergy(v, &vout); + m_melfilterbank->ComputeMelFilterBankEnergy(v, vout); vout.ApplyFloor(std::numeric_limits::min()); vout.ApplyLog(); param_2->AddMatVec(1.0, m_dct_matrix, MatrixTransposeType::kNoTrans, vout, 0.0); diff --git a/lib/mfcc-stream.h b/lib/mfcc-stream.h index 5f07ddf..d380fd7 100644 --- a/lib/mfcc-stream.h +++ b/lib/mfcc-stream.h @@ -13,7 +13,6 @@ namespace snowboy { void Register(const std::string& prefix, OptionsItf* opts); }; - static_assert(sizeof(MfccStreamOptions) == 0x2c); struct MfccStream : StreamItf { MfccStreamOptions m_options; int field_x44; @@ -31,5 +30,4 @@ namespace snowboy { void InitMelFilterBank(int); void ComputeMfcc(const VectorBase&, SubVector*) const; }; - static_assert(sizeof(MfccStream) == 0x80); } // namespace snowboy \ No newline at end of file diff --git a/lib/nnet-component.cpp b/lib/nnet-component.cpp index b01c703..a6106a6 100644 --- a/lib/nnet-component.cpp +++ b/lib/nnet-component.cpp @@ -299,7 +299,6 @@ namespace snowboy { in_info.CheckSize(in); out_info.CheckSize(*out); - //SNOWBOY_ERROR() << "Unimplemented"; for (size_t r = 0; r < in.m_rows; r++) { if (out->m_cols < 2) @@ -308,19 +307,16 @@ namespace snowboy { } else { // TODO: I did my best but between here auto ptr = out->m_data + (out->m_stride * r); - float x = 0.0; - auto ptr2 = ptr; + float sum = 0.0f; for (auto& idx_vec : m_indices) { - ptr2++; - auto f = *ptr2; + ptr++; for (auto idx : idx_vec) { auto v = in.m_data[in.m_stride * r + idx]; - f += v; - x += v; + *ptr += v; + sum += v; } - *ptr2 = f; } - out->m_data[out->m_stride * r] = 1.0 - x; + out->m_data[out->m_stride * r] = 1.0 - sum; // TODO: and here are probably a number of bugs } } diff --git a/lib/nnet-component.h b/lib/nnet-component.h index 5533c43..a00ea1a 100644 --- a/lib/nnet-component.h +++ b/lib/nnet-component.h @@ -8,7 +8,6 @@ namespace snowboy { struct MatrixBase; - // TODO: This is kaldi::ChunkInfo class ChunkInfo { int32_t m_feat_dim; int32_t m_num_chunks; @@ -18,13 +17,13 @@ namespace snowboy { friend std::ostream& operator<<(std::ostream& os, const ChunkInfo& e); public: - ChunkInfo() // default constructor we assume this object will not be used + ChunkInfo() noexcept // default constructor we assume this object will not be used : m_feat_dim(0), m_num_chunks(0), m_first_offset(0), m_last_offset(0), m_offsets() {} ChunkInfo(int32_t feat_dim, int32_t num_chunks, - int32_t first_offset, int32_t last_offset) + int32_t first_offset, int32_t last_offset) noexcept : m_feat_dim(feat_dim), m_num_chunks(num_chunks), m_first_offset(first_offset), m_last_offset(last_offset), m_offsets() { Check(); } @@ -55,7 +54,6 @@ namespace snowboy { Check(); } }; - static_assert(sizeof(ChunkInfo) == 0x28); std::ostream& operator<<(std::ostream& os, const ChunkInfo& e); class Component { @@ -89,7 +87,6 @@ namespace snowboy { Component(Component&&) = delete; Component& operator=(Component&&) = delete; }; - static_assert(sizeof(Component) == 0x10); // Actually 0xc, but padding.... class AffineComponent : public Component { bool m_is_gradient = 1; @@ -110,7 +107,6 @@ namespace snowboy { virtual Component* Copy() const override; virtual ~AffineComponent() {} }; - static_assert(sizeof(AffineComponent) == 0x38); // m_is_gradient is inside the padding of Component class CmvnComponent : public Component { bool field_xc = 0; @@ -131,7 +127,6 @@ namespace snowboy { virtual Component* Copy() const override; virtual ~CmvnComponent() {} }; - static_assert(sizeof(CmvnComponent) == 0x30); // field_xc is inside the padding of Component class NormalizeComponent : public Component { int32_t m_dim = 0; @@ -152,7 +147,6 @@ namespace snowboy { virtual Component* Copy() const override; virtual ~NormalizeComponent() {} }; - static_assert(sizeof(NormalizeComponent) == 0x18); // m_dim is inside the padding of Component class PosteriorMapComponent : public Component { bool field_xc; @@ -174,7 +168,6 @@ namespace snowboy { virtual Component* Copy() const override; virtual ~PosteriorMapComponent() {} }; - static_assert(sizeof(PosteriorMapComponent) == 0x30); // field_xc is inside the padding of Component class RectifiedLinearComponent : public Component { int32_t m_dim; @@ -194,7 +187,6 @@ namespace snowboy { virtual Component* Copy() const override; virtual ~RectifiedLinearComponent() {} }; - static_assert(sizeof(RectifiedLinearComponent) == 0x18); // m_dim is inside the padding of Component class SoftmaxComponent : public Component { int32_t m_dim; @@ -214,7 +206,6 @@ namespace snowboy { virtual Component* Copy() const override; virtual ~SoftmaxComponent() {} }; - static_assert(sizeof(SoftmaxComponent) == 0x18); // field_xc is inside the padding of Component class SpliceComponent : public Component { bool field_xc; @@ -238,5 +229,4 @@ namespace snowboy { virtual Component* Copy() const override; virtual ~SpliceComponent() {} }; - static_assert(sizeof(SpliceComponent) == 0x30); // field_xc is inside the padding of Component } // namespace snowboy \ No newline at end of file diff --git a/lib/nnet-lib.h b/lib/nnet-lib.h index 2997bfe..6b10a30 100644 --- a/lib/nnet-lib.h +++ b/lib/nnet-lib.h @@ -55,5 +55,4 @@ namespace snowboy { int32_t LeftContext() const; int32_t RightContext() const; }; - static_assert(sizeof(Nnet) == 0x110); } // namespace snowboy \ No newline at end of file diff --git a/lib/nnet-stream.h b/lib/nnet-stream.h index a568e4f..2480af7 100644 --- a/lib/nnet-stream.h +++ b/lib/nnet-stream.h @@ -11,7 +11,6 @@ namespace snowboy { bool pad_context; void Register(const std::string&, OptionsItf*); }; - static_assert(sizeof(NnetStreamOptions) == 0x10); struct NnetStream : StreamItf { NnetStreamOptions m_options; std::unique_ptr m_nnet; @@ -22,5 +21,4 @@ namespace snowboy { virtual std::string Name() const override; virtual ~NnetStream(); }; - static_assert(sizeof(NnetStream) == 0x30); } // namespace snowboy \ No newline at end of file diff --git a/lib/pipeline-detect.cpp b/lib/pipeline-detect.cpp index 9e3bc7e..9c48aa8 100644 --- a/lib/pipeline-detect.cpp +++ b/lib/pipeline-detect.cpp @@ -308,9 +308,9 @@ namespace snowboy { auto num_personal = m_templateDetectStream == nullptr ? 0 : m_templateDetectStream->field_x40.size(); auto num_universal = 0; if (m_universalDetectStream != nullptr - && !m_universalDetectStream->field_xd0.empty() - && !m_universalDetectStream->field_xd0.back().empty()) { - num_universal = m_universalDetectStream->field_xd0.back().back(); + && !m_universalDetectStream->m_model_info.empty() + && !m_universalDetectStream->m_model_info.back().keywords.empty()) { + num_universal = m_universalDetectStream->m_model_info.back().keywords.back().hotword_id; } if (num_universal + num_personal < parts.size()) { SNOWBOY_ERROR() << "number of hotwords and number of sensitivities mismatch, expecting sensitivities for " @@ -376,8 +376,8 @@ namespace snowboy { } int num_hotwords = 0; if (m_templateDetectStream) num_hotwords += m_templateDetectStream->field_x40.size(); - if (m_universalDetectStream && !m_universalDetectStream->field_xd0.empty() && !m_universalDetectStream->field_xd0.back().empty()) - num_hotwords += m_universalDetectStream->field_xd0.back().back(); + if (m_universalDetectStream && !m_universalDetectStream->m_model_info.empty() && !m_universalDetectStream->m_model_info.back().keywords.empty()) + num_hotwords += m_universalDetectStream->m_model_info.back().keywords.back().hotword_id; return num_hotwords; } diff --git a/lib/pipeline-detect.h b/lib/pipeline-detect.h index d109c5a..2e3ccb8 100644 --- a/lib/pipeline-detect.h +++ b/lib/pipeline-detect.h @@ -40,7 +40,6 @@ namespace snowboy { // Padding void Register(const std::string&, OptionsItf*); }; - static_assert(sizeof(PipelineDetectOptions) == 8); struct PipelineDetect : PipelineItf { // Virtual stuff @@ -108,13 +107,5 @@ namespace snowboy { bool field_x168 = false; bool field_x169 = false; - char data2[6]; }; - static_assert(sizeof(PipelineDetect) == 368); - static_assert(sizeof(PipelineDetect::m_eavesdropStreamFrameInfoVector) == 24); -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Winvalid-offsetof" - static_assert(offsetof(PipelineDetect, m_universalDetectStreamOptions) == 0xf0); - static_assert(offsetof(PipelineDetect, m_eavesdropStreamFrameInfoVector) == 0xf8); -#pragma GCC diagnostic pop } // namespace snowboy \ No newline at end of file diff --git a/lib/pipeline-personal-enroll.h b/lib/pipeline-personal-enroll.h index 76a1b7f..e7213a1 100644 --- a/lib/pipeline-personal-enroll.h +++ b/lib/pipeline-personal-enroll.h @@ -54,5 +54,4 @@ namespace snowboy { void SetModelFilename(const std::string& filename); int GetNumTemplates() const; }; - static_assert(sizeof(PipelinePersonalEnroll) == 0x70); } // namespace snowboy \ No newline at end of file diff --git a/lib/pipeline-template-cut.h b/lib/pipeline-template-cut.h index 524deea..b082018 100644 --- a/lib/pipeline-template-cut.h +++ b/lib/pipeline-template-cut.h @@ -28,7 +28,6 @@ namespace snowboy { void Register(const std::string& prefix, OptionsItf* opts); }; - static_assert(sizeof(PipelineTemplateCutOptions) == 0x10); struct PipelineTemplateCut : PipelineItf { std::unique_ptr m_interceptStream; std::unique_ptr m_framerStream; @@ -57,5 +56,4 @@ namespace snowboy { int CutTemplate(const MatrixBase& in, Matrix* out); void ComputeTemplateBoundary(const MatrixBase&, const std::vector&, int*, int*) const; }; - static_assert(sizeof(PipelineTemplateCut) == 0xa0); } // namespace snowboy \ No newline at end of file diff --git a/lib/pipeline-vad.cpp b/lib/pipeline-vad.cpp index 75c6ed9..a693fa2 100644 --- a/lib/pipeline-vad.cpp +++ b/lib/pipeline-vad.cpp @@ -87,8 +87,7 @@ namespace snowboy { m_eavesdropStream->Connect(m_rawNnetVadStream.get()); m_vadStateStream2->Connect(m_eavesdropStream.get()); m_vadStateStream->field_x2c = 1; - m_vadStateStream->field_x2c = 2; - // TODO: Options are cleaned here in the original + m_vadStateStream2->field_x2c = 2; m_isInitialized = true; return true; } diff --git a/lib/pipeline-vad.h b/lib/pipeline-vad.h index d08713b..d24fe5e 100644 --- a/lib/pipeline-vad.h +++ b/lib/pipeline-vad.h @@ -74,5 +74,4 @@ namespace snowboy { void SetAudioGain(float gain); void SetMaxAudioAmplitude(float maxAmplitude); }; - static_assert(sizeof(PipelineVad) == 0xd8); } // namespace snowboy \ No newline at end of file diff --git a/lib/raw-energy-vad-stream.h b/lib/raw-energy-vad-stream.h index a88aeda..391036b 100644 --- a/lib/raw-energy-vad-stream.h +++ b/lib/raw-energy-vad-stream.h @@ -15,7 +15,6 @@ namespace snowboy { int raw_buffer_extra; void Register(const std::string&, OptionsItf*); }; - static_assert(sizeof(RawEnergyVadStreamOptions) == 20); struct RawEnergyVadStream : StreamItf { RawEnergyVadStreamOptions m_options; bool field_x2c; @@ -35,5 +34,4 @@ namespace snowboy { void InitRawEnergyVad(Matrix*, std::vector*); void UpdateBackgroundEnergy(const std::vector&); }; - static_assert(sizeof(RawEnergyVadStream) == 0x108); } // namespace snowboy \ No newline at end of file diff --git a/lib/raw-nnet-vad-stream.cpp b/lib/raw-nnet-vad-stream.cpp index 5cd59f4..d596239 100644 --- a/lib/raw-nnet-vad-stream.cpp +++ b/lib/raw-nnet-vad-stream.cpp @@ -26,7 +26,6 @@ namespace snowboy { SNOWBOY_ERROR() << "index " << m_options.non_voice_index << " for non-voice label runs out of range (0 - " << dims << "), wrong index?"; return; } - m_fieldx30.Resize(0, 0); // TODO: Useless call } int RawNnetVadStream::Read(Matrix* mat, std::vector* info) { diff --git a/lib/raw-nnet-vad-stream.h b/lib/raw-nnet-vad-stream.h index 74107ac..1454d42 100644 --- a/lib/raw-nnet-vad-stream.h +++ b/lib/raw-nnet-vad-stream.h @@ -15,7 +15,6 @@ namespace snowboy { std::string model_filename; void Register(const std::string&, OptionsItf*); }; - static_assert(sizeof(RawNnetVadStreamOptions) == 0x10); struct RawNnetVadStream : StreamItf { RawNnetVadStreamOptions m_options; std::unique_ptr m_nnet; @@ -28,5 +27,4 @@ namespace snowboy { virtual std::string Name() const override; virtual ~RawNnetVadStream(); }; - static_assert(sizeof(RawNnetVadStream) == 0x48); } // namespace snowboy \ No newline at end of file diff --git a/lib/snowboy-debug.cpp b/lib/snowboy-debug.cpp index 762cdce..061be98 100644 --- a/lib/snowboy-debug.cpp +++ b/lib/snowboy-debug.cpp @@ -23,11 +23,11 @@ namespace snowboy { } void SnowboyAssertFailure(int line, const std::string& file, const std::string& func, const std::string& cond) { - snowboy::MySnowboyLogMsg msg{line, file, func, snowboy::SnowboyLogType::ASSERT_FAIL, 0}; + snowboy::SnowboyLogMsg msg{line, file, func, snowboy::SnowboyLogType::ASSERT_FAIL, 0}; msg << cond; } - MySnowboyLogMsg::MySnowboyLogMsg(int line, const std::string& file, const std::string& function, const SnowboyLogType& type, int) + SnowboyLogMsg::SnowboyLogMsg(int line, const std::string& file, const std::string& function, const SnowboyLogType& type, int) : m_type{type}, m_stream{} { switch (type) { @@ -50,13 +50,17 @@ namespace snowboy { m_stream << function << "():" << file << ":" << line << ") "; } - MySnowboyLogMsg::~MySnowboyLogMsg() noexcept(false) { - std::cout << m_stream.str() << std::endl; - if (m_type == SnowboyLogType::ERROR || m_type == SnowboyLogType::ASSERT_FAIL) + SnowboyLogMsg::~SnowboyLogMsg() noexcept(false) { + std::cerr << m_stream.str() << std::endl; + if (m_type == SnowboyLogType::ERROR) { m_stream << GetStackTrace(); // TODO: This isnt normally allowed.... throw std::runtime_error(m_stream.str()); + } else if (m_type == SnowboyLogType::ERROR) + { + std::cerr << GetStackTrace(); + std::abort(); } } } // namespace snowboy \ No newline at end of file diff --git a/lib/snowboy-debug.h b/lib/snowboy-debug.h index 38f5017..e27efb1 100644 --- a/lib/snowboy-debug.h +++ b/lib/snowboy-debug.h @@ -12,29 +12,28 @@ namespace snowboy { LOG, VLOG }; - struct MySnowboyLogMsg { + struct SnowboyLogMsg { SnowboyLogType m_type; std::stringstream m_stream; - MySnowboyLogMsg(int line, const std::string&, const std::string&, const SnowboyLogType& type, int); - ~MySnowboyLogMsg() noexcept(false); + SnowboyLogMsg(int line, const std::string&, const std::string&, const SnowboyLogType& type, int); + ~SnowboyLogMsg() noexcept(false); template - MySnowboyLogMsg& operator<<(T&& val) { + SnowboyLogMsg& operator<<(T&& val) { m_stream << val; return *this; } }; - using SnowboyLogMsg = MySnowboyLogMsg; } // namespace snowboy #define SNOWBOY_ERROR() \ - snowboy::MySnowboyLogMsg { __LINE__, __FILE__, __FUNCTION__, snowboy::SnowboyLogType::ERROR, 0 } + snowboy::SnowboyLogMsg { __LINE__, __FILE__, __FUNCTION__, snowboy::SnowboyLogType::ERROR, 0 } #define SNOWBOY_WARNING() \ - snowboy::MySnowboyLogMsg { __LINE__, __FILE__, __FUNCTION__, snowboy::SnowboyLogType::WARNING, 0 } + snowboy::SnowboyLogMsg { __LINE__, __FILE__, __FUNCTION__, snowboy::SnowboyLogType::WARNING, 0 } #define SNOWBOY_LOG() \ - snowboy::MySnowboyLogMsg { __LINE__, __FILE__, __FUNCTION__, snowboy::SnowboyLogType::LOG, 0 } + snowboy::SnowboyLogMsg { __LINE__, __FILE__, __FUNCTION__, snowboy::SnowboyLogType::LOG, 0 } #define SNOWBOY_VLOG() \ - snowboy::MySnowboyLogMsg { __LINE__, __FILE__, __FUNCTION__, snowboy::SnowboyLogType::VLOG, 0 } + snowboy::SnowboyLogMsg { __LINE__, __FILE__, __FUNCTION__, snowboy::SnowboyLogType::VLOG, 0 } #ifndef NDEBUG #define SNOWBOY_ASSERT(cond) \ diff --git a/lib/snowboy-detect-c.cpp b/lib/snowboy-detect-c.cpp new file mode 100644 index 0000000..9a9bc9d --- /dev/null +++ b/lib/snowboy-detect-c.cpp @@ -0,0 +1,673 @@ +#include +#include +#include +#include + +extern "C" +{ + void SNOWMAN_free(void* ptr) { + free(ptr); + } + + struct SNOWMAN_Detect : snowboy::SnowboyDetect { + using SnowboyDetect::SnowboyDetect; + }; + + struct SNOWMAN_Vad : snowboy::SnowboyVad { + using SnowboyVad::SnowboyVad; + }; + + struct SNOWMAN_PersonalEnroll : snowboy::SnowboyPersonalEnroll { + using SnowboyPersonalEnroll::SnowboyPersonalEnroll; + }; + + struct SNOWMAN_TemplateCut : snowboy::SnowboyTemplateCut { + using SnowboyTemplateCut::SnowboyTemplateCut; + }; + + SNOWMAN_Detect* SNOWMAN_Detect_Create(const char* resource_filename, const char* model_str) { + if (resource_filename == nullptr) resource_filename = "common.res"; + if (model_str == nullptr) model_str = "model.umdl"; + try { + return new SNOWMAN_Detect{resource_filename, model_str}; + } catch (...) { + errno = EIO; + return nullptr; + } + } + + int SNOWMAN_Detect_Reset(SNOWMAN_Detect* instance) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + try { + return instance->Reset() ? 1 : 0; + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_Detect_RunDetectionWave(SNOWMAN_Detect* instance, const void* data, unsigned int len, int is_end) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + if (data == nullptr || len == 0) return 0; + try { + std::string temp{reinterpret_cast(data), len}; + return instance->RunDetection(temp, is_end != 0); + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_Detect_RunDetectionFloat(SNOWMAN_Detect* instance, const float* data, unsigned int num_samples, int is_end) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + if (data == nullptr || num_samples == 0) return 0; + try { + return instance->RunDetection(data, num_samples, is_end != 0); + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_Detect_RunDetectionShort(SNOWMAN_Detect* instance, const short* data, unsigned int num_samples, int is_end) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + if (data == nullptr || num_samples == 0) return 0; + try { + return instance->RunDetection(data, num_samples, is_end != 0); + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_Detect_RunDetectionInt(SNOWMAN_Detect* instance, const int* data, unsigned int num_samples, int is_end) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + if (data == nullptr || num_samples == 0) return 0; + try { + return instance->RunDetection(data, num_samples, is_end != 0); + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_Detect_SetSensitivity(SNOWMAN_Detect* instance, const char* sensitivity) { + if (instance == nullptr || sensitivity == nullptr) { + errno = EINVAL; + return -1; + } + try { + instance->SetSensitivity(sensitivity); + return 0; + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_Detect_SetHighSensitivity(SNOWMAN_Detect* instance, const char* sensitivity) { + if (instance == nullptr || sensitivity == nullptr) { + errno = EINVAL; + return -1; + } + try { + instance->SetHighSensitivity(sensitivity); + return 0; + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_Detect_GetSensitivity(SNOWMAN_Detect* instance, char** pointer) { + if (pointer == nullptr) return 0; + if (instance == nullptr || *pointer != nullptr) { + errno = EINVAL; + return -1; + } + try { + auto s = instance->GetSensitivity(); + *pointer = static_cast(malloc(s.size() + 1)); + if (*pointer == nullptr) { + errno = ENOMEM; + return -1; + } + strcpy(*pointer, s.c_str()); + (*pointer)[s.size()] = '\0'; + return 0; + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_Detect_SetAudioGain(SNOWMAN_Detect* instance, float gain) { + if (instance == nullptr || gain <= 0) { + errno = EINVAL; + return -1; + } + try { + instance->SetAudioGain(gain); + return 0; + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_Detect_UpdateModel(SNOWMAN_Detect* instance) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + try { + instance->UpdateModel(); + return 0; + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_Detect_NumHotwords(SNOWMAN_Detect* instance) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + try { + return instance->NumHotwords(); + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_Detect_ApplyFrontend(SNOWMAN_Detect* instance, int apply) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + try { + instance->ApplyFrontend(apply != 0); + return 0; + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_Detect_SampleRate(SNOWMAN_Detect* instance) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + try { + return instance->SampleRate(); + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_Detect_NumChannels(SNOWMAN_Detect* instance) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + try { + return instance->NumChannels(); + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_Detect_BitsPerSample(SNOWMAN_Detect* instance) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + try { + return instance->BitsPerSample(); + } catch (...) { + errno = EIO; + return -1; + } + } + + void SNOWMAN_Detect_Destroy(SNOWMAN_Detect* instance) { + if (instance == nullptr) { + errno = EINVAL; + return; + } + try { + delete instance; + } catch (...) { + // Should never happen, but better be safe than sorry + errno = EIO; + return; + } + } + + SNOWMAN_Vad* SNOWMAN_Vad_Create(const char* resource_filename) { + if (resource_filename == nullptr) resource_filename = "common.res"; + try { + return new SNOWMAN_Vad{resource_filename}; + } catch (...) { + errno = EIO; + return nullptr; + } + } + + int SNOWMAN_Vad_Reset(SNOWMAN_Vad* instance) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + try { + return instance->Reset() ? 1 : 0; + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_Vad_RunVadWave(SNOWMAN_Vad* instance, const void* data, unsigned int len, int is_end) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + if (data == nullptr || len == 0) return 0; + try { + std::string temp{reinterpret_cast(data), len}; + return instance->RunVad(temp, is_end != 0); + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_Vad_RunVadFloat(SNOWMAN_Vad* instance, const float* data, unsigned int num_samples, int is_end) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + if (data == nullptr || num_samples == 0) return 0; + try { + return instance->RunVad(data, num_samples, is_end != 0); + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_Vad_RunVadShort(SNOWMAN_Vad* instance, const short* data, unsigned int num_samples, int is_end) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + if (data == nullptr || num_samples == 0) return 0; + try { + return instance->RunVad(data, num_samples, is_end != 0); + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_Vad_RunVadInt(SNOWMAN_Vad* instance, const int* data, unsigned int num_samples, int is_end) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + if (data == nullptr || num_samples == 0) return 0; + try { + return instance->RunVad(data, num_samples, is_end != 0); + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_Vad_SetAudioGain(SNOWMAN_Vad* instance, float gain) { + if (instance == nullptr || gain <= 0) { + errno = EINVAL; + return -1; + } + try { + instance->SetAudioGain(gain); + return 0; + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_Vad_ApplyFrontend(SNOWMAN_Vad* instance, int apply) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + try { + instance->ApplyFrontend(apply != 0); + return 0; + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_Vad_SampleRate(SNOWMAN_Vad* instance) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + try { + return instance->SampleRate(); + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_Vad_NumChannels(SNOWMAN_Vad* instance) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + try { + return instance->NumChannels(); + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_Vad_BitsPerSample(SNOWMAN_Vad* instance) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + try { + return instance->BitsPerSample(); + } catch (...) { + errno = EIO; + return -1; + } + } + + void SNOWMAN_Vad_Destroy(SNOWMAN_Vad* instance) { + if (instance == nullptr) { + errno = EINVAL; + return; + } + try { + delete instance; + } catch (...) { + // Should never happen, but better be safe than sorry + errno = EIO; + return; + } + } + + SNOWMAN_PersonalEnroll* SNOWMAN_PersonalEnroll_Create(const char* resource_filename, const char* model_str) { + if (resource_filename == nullptr) resource_filename = "common.res"; + if (model_str == nullptr) model_str = "model.pmdl"; + try { + return new SNOWMAN_PersonalEnroll{resource_filename, model_str}; + } catch (...) { + errno = EIO; + return nullptr; + } + } + + int SNOWMAN_PersonalEnroll_Reset(SNOWMAN_PersonalEnroll* instance) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + try { + return instance->Reset() ? 1 : 0; + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_PersonalEnroll_RunEnrollmentWave(SNOWMAN_PersonalEnroll* instance, const void* data, unsigned int len) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + if (data == nullptr || len == 0) return 0; + try { + std::string temp{reinterpret_cast(data), len}; + return instance->RunEnrollment(temp); + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_PersonalEnroll_RunEnrollmentFloat(SNOWMAN_PersonalEnroll* instance, const float* data, unsigned int num_samples) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + if (data == nullptr || num_samples == 0) return 0; + try { + return instance->RunEnrollment(data, num_samples); + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_PersonalEnroll_RunEnrollmentShort(SNOWMAN_PersonalEnroll* instance, const short* data, unsigned int num_samples) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + if (data == nullptr || num_samples == 0) return 0; + try { + return instance->RunEnrollment(data, num_samples); + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_PersonalEnroll_RunEnrollmentInt(SNOWMAN_PersonalEnroll* instance, const int* data, unsigned int num_samples) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + if (data == nullptr || num_samples == 0) return 0; + try { + return instance->RunEnrollment(data, num_samples); + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_PersonalEnroll_GetNumTemplates(SNOWMAN_PersonalEnroll* instance) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + try { + return instance->GetNumTemplates(); + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_PersonalEnroll_SampleRate(SNOWMAN_PersonalEnroll* instance) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + try { + return instance->SampleRate(); + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_PersonalEnroll_NumChannels(SNOWMAN_PersonalEnroll* instance) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + try { + return instance->NumChannels(); + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_PersonalEnroll_BitsPerSample(SNOWMAN_PersonalEnroll* instance) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + try { + return instance->BitsPerSample(); + } catch (...) { + errno = EIO; + return -1; + } + } + + void SNOWMAN_PersonalEnroll_Destroy(SNOWMAN_PersonalEnroll* instance) { + if (instance == nullptr) { + errno = EINVAL; + return; + } + try { + delete instance; + } catch (...) { + // Should never happen, but better be safe than sorry + errno = EIO; + return; + } + } + + SNOWMAN_TemplateCut* SNOWMAN_TemplateCut_Create(const char* resource_filename) { + if (resource_filename == nullptr) resource_filename = "common.res"; + try { + return new SNOWMAN_TemplateCut{resource_filename}; + } catch (...) { + errno = EIO; + return nullptr; + } + } + + int SNOWMAN_TemplateCut_Reset(SNOWMAN_TemplateCut* instance) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + try { + return instance->Reset() ? 1 : 0; + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_TemplateCut_CutTemplateWave(SNOWMAN_TemplateCut* instance, const void* indata, unsigned int inlen, void** outdata, unsigned int* outlen) { + if (indata == nullptr || inlen == 0 || outdata == nullptr || outlen == 0) return 0; + if (instance == nullptr || *outdata != nullptr || *outlen != 0) { + errno = EINVAL; + return -1; + } + try { + std::string temp{reinterpret_cast(indata), inlen}; + auto res = instance->CutTemplate(temp); + *outdata = static_cast(malloc(res.size())); + if (*outdata == nullptr) { + errno = ENOMEM; + return -1; + } + memcpy(*outdata, res.data(), res.size()); + *outlen = res.size(); + return 0; + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_TemplateCut_SampleRate(SNOWMAN_TemplateCut* instance) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + try { + return instance->SampleRate(); + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_TemplateCut_NumChannels(SNOWMAN_TemplateCut* instance) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + try { + return instance->NumChannels(); + } catch (...) { + errno = EIO; + return -1; + } + } + + int SNOWMAN_TemplateCut_BitsPerSample(SNOWMAN_TemplateCut* instance) { + if (instance == nullptr) { + errno = EINVAL; + return -1; + } + try { + return instance->BitsPerSample(); + } catch (...) { + errno = EIO; + return -1; + } + } + + void SNOWMAN_TemplateCut_Destroy(SNOWMAN_TemplateCut* instance) { + if (instance == nullptr) { + errno = EINVAL; + return; + } + try { + delete instance; + } catch (...) { + // Should never happen, but better be safe than sorry + errno = EIO; + return; + } + } +} diff --git a/lib/snowboy-detect-c.h b/lib/snowboy-detect-c.h new file mode 100644 index 0000000..a3f9df2 --- /dev/null +++ b/lib/snowboy-detect-c.h @@ -0,0 +1,69 @@ +#pragma once +#ifdef __cplusplus +extern "C" +{ +#endif + + extern void SNOWMAN_free(void*); + + struct SNOWMAN_Detect; + struct SNOWMAN_Vad; + struct SNOWMAN_PersonalEnroll; + struct SNOWMAN_TemplateCut; + + SNOWMAN_Detect* SNOWMAN_Detect_Create(const char* resource_filename, const char* model_str); + int SNOWMAN_Detect_Reset(SNOWMAN_Detect* instance); + int SNOWMAN_Detect_RunDetectionWave(SNOWMAN_Detect* instance, const void* data, unsigned int len, int is_end); + int SNOWMAN_Detect_RunDetectionFloat(SNOWMAN_Detect* instance, const float* data, unsigned int num_samples, int is_end); + int SNOWMAN_Detect_RunDetectionShort(SNOWMAN_Detect* instance, const short* data, unsigned int num_samples, int is_end); + int SNOWMAN_Detect_RunDetectionInt(SNOWMAN_Detect* instance, const int* data, unsigned int num_samples, int is_end); + int SNOWMAN_Detect_SetSensitivity(SNOWMAN_Detect* instance, const char* sensitivity); + int SNOWMAN_Detect_SetHighSensitivity(SNOWMAN_Detect* instance, const char* sensitivity); + // Returned pointer needs to get freed using SNOWMAN_free + int SNOWMAN_Detect_GetSensitivity(SNOWMAN_Detect* instance, char** pointer); + int SNOWMAN_Detect_SetAudioGain(SNOWMAN_Detect* instance, float gain); + int SNOWMAN_Detect_UpdateModel(SNOWMAN_Detect* instance); + int SNOWMAN_Detect_NumHotwords(SNOWMAN_Detect* instance); + int SNOWMAN_Detect_ApplyFrontend(SNOWMAN_Detect* instance, int apply); + int SNOWMAN_Detect_SampleRate(SNOWMAN_Detect* instance); + int SNOWMAN_Detect_NumChannels(SNOWMAN_Detect* instance); + int SNOWMAN_Detect_BitsPerSample(SNOWMAN_Detect* instance); + void SNOWMAN_Detect_Destroy(SNOWMAN_Detect* instance); + + SNOWMAN_Vad* SNOWMAN_Vad_Create(const char* resource_filename); + int SNOWMAN_Vad_Reset(SNOWMAN_Vad* instance); + int SNOWMAN_Vad_RunVadWave(SNOWMAN_Vad* instance, const void* data, unsigned int len, int is_end); + int SNOWMAN_Vad_RunVadFloat(SNOWMAN_Vad* instance, const float* data, unsigned int num_samples, int is_end); + int SNOWMAN_Vad_RunVadShort(SNOWMAN_Vad* instance, const short* data, unsigned int num_samples, int is_end); + int SNOWMAN_Vad_RunVadInt(SNOWMAN_Vad* instance, const int* data, unsigned int num_samples, int is_end); + int SNOWMAN_Vad_SetAudioGain(SNOWMAN_Vad* instance, float gain); + int SNOWMAN_Vad_ApplyFrontend(SNOWMAN_Vad* instance, int apply); + int SNOWMAN_Vad_SampleRate(SNOWMAN_Vad* instance); + int SNOWMAN_Vad_NumChannels(SNOWMAN_Vad* instance); + int SNOWMAN_Vad_BitsPerSample(SNOWMAN_Vad* instance); + void SNOWMAN_Vad_Destroy(SNOWMAN_Vad* instance); + + SNOWMAN_PersonalEnroll* SNOWMAN_PersonalEnroll_Create(const char* resource_filename, const char* model_str); + int SNOWMAN_PersonalEnroll_Reset(SNOWMAN_PersonalEnroll* instance); + int SNOWMAN_PersonalEnroll_RunEnrollmentWave(SNOWMAN_PersonalEnroll* instance, const void* data, unsigned int len); + int SNOWMAN_PersonalEnroll_RunEnrollmentFloat(SNOWMAN_PersonalEnroll* instance, const float* data, unsigned int num_samples); + int SNOWMAN_PersonalEnroll_RunEnrollmentShort(SNOWMAN_PersonalEnroll* instance, const short* data, unsigned int num_samples); + int SNOWMAN_PersonalEnroll_RunEnrollmentInt(SNOWMAN_PersonalEnroll* instance, const int* data, unsigned int num_samples); + int SNOWMAN_PersonalEnroll_GetNumTemplates(SNOWMAN_PersonalEnroll* instance); + int SNOWMAN_PersonalEnroll_SampleRate(SNOWMAN_PersonalEnroll* instance); + int SNOWMAN_PersonalEnroll_NumChannels(SNOWMAN_PersonalEnroll* instance); + int SNOWMAN_PersonalEnroll_BitsPerSample(SNOWMAN_PersonalEnroll* instance); + void SNOWMAN_PersonalEnroll_Destroy(SNOWMAN_PersonalEnroll* instance); + + SNOWMAN_TemplateCut* SNOWMAN_TemplateCut_Create(const char* resource_filename); + int SNOWMAN_TemplateCut_Reset(SNOWMAN_TemplateCut* instance); + // Returned data needs to get freed using SNOWMAN_free + int SNOWMAN_TemplateCut_CutTemplateWave(SNOWMAN_TemplateCut* instance, const void* indata, unsigned int inlen, void** outdata, unsigned int* outlen); + int SNOWMAN_TemplateCut_SampleRate(SNOWMAN_TemplateCut* instance); + int SNOWMAN_TemplateCut_NumChannels(SNOWMAN_TemplateCut* instance); + int SNOWMAN_TemplateCut_BitsPerSample(SNOWMAN_TemplateCut* instance); + void SNOWMAN_TemplateCut_Destroy(SNOWMAN_TemplateCut* instance); + +#ifdef __cplusplus +} +#endif diff --git a/lib/snowboy-detect.cpp b/lib/snowboy-detect.cpp index cf3e26e..7a1030b 100644 --- a/lib/snowboy-detect.cpp +++ b/lib/snowboy-detect.cpp @@ -7,7 +7,7 @@ #include #include #include -#include +#include namespace snowboy { SnowboyDetect::SnowboyDetect(const std::string& resource_filename, const std::string& model_str) { @@ -35,9 +35,10 @@ namespace snowboy { } int SnowboyDetect::RunDetection(const std::string& data, bool is_end) { - SNOWBOY_ERROR() << "Not implemented"; - return -1; - // TODO: Parse WAVE header and run detection on data block + if ((data.size() % wave_header_->wBlockAlign) != 0) return -1; + Matrix data_mat; + ReadRawWaveFromString(*wave_header_, data, &data_mat); + return detect_pipeline_->RunDetection(data_mat, is_end); } int SnowboyDetect::RunDetection(const float* const data, const int array_length, bool is_end) { @@ -162,9 +163,10 @@ namespace snowboy { } int SnowboyVad::RunVad(const std::string& data, bool is_end) { - SNOWBOY_ERROR() << "Not implemented"; - return -1; - // TODO: Parse WAVE header and run detection on data block + if ((data.size() % wave_header_->wBlockAlign) != 0) return -1; + Matrix data_mat; + ReadRawWaveFromString(*wave_header_, data, &data_mat); + return vad_pipeline_->RunVad(data_mat, is_end); } int SnowboyVad::RunVad(const float* const data, const int array_length, bool is_end) { diff --git a/lib/snowboy-detect.h b/lib/snowboy-detect.h index 852e42b..5a260a1 100644 --- a/lib/snowboy-detect.h +++ b/lib/snowboy-detect.h @@ -218,7 +218,6 @@ namespace snowboy { std::unique_ptr vad_pipeline_; }; - // TODO: This is untested class SnowboyPersonalEnroll { public: SnowboyPersonalEnroll(const std::string& resource_filename, const std::string& model_filename); @@ -276,11 +275,21 @@ namespace snowboy { std::unique_ptr enroll_pipeline_; }; - // TODO: This is untested class SnowboyTemplateCut { public: SnowboyTemplateCut(const std::string& resource_filename); + // Cuts a template. Supported audio format is WAVE (with linear PCM, + // 8-bits unsigned integer, 16-bits signed integer or 32-bits signed integer). + // See SampleRate(), NumChannels() and BitsPerSample() for the required + // sampling rate, number of channels and bits per sample values. You are + // supposed to provide a full recording of the hotword for each call to + // CutTemplate. This method runs runs the provided sample through a Vad Pipeline + // and removes leading and trailing silence. + // + // @param [in] data Small chunk of data to be detected. See + // above for the supported data format. + // @return Cut template in the format provided in data, without a wave header. std::string CutTemplate(const std::string& data); bool Reset(); diff --git a/lib/snowboy-options.cpp b/lib/snowboy-options.cpp index 4981f50..f8b1c15 100644 --- a/lib/snowboy-options.cpp +++ b/lib/snowboy-options.cpp @@ -96,11 +96,8 @@ namespace snowboy { m_opt_print_usage = false; m_usage = usage; Register("", "config", "Configuration file to be read.", &m_opt_config_file); - //TODO: field_0x70.push_back("config"); Register("", "help", "If true, print usage information.", &m_opt_print_usage); - //TODO: field_0x70.push_back("help"); Register("", "verbose", "Verbose level.", &global_snowboy_verbose_level); - //TODO: field_0x70.push_back("verbose"); } ParseOptions::~ParseOptions() {} diff --git a/lib/snowboy-options.h b/lib/snowboy-options.h index c1cbbfb..af4caab 100644 --- a/lib/snowboy-options.h +++ b/lib/snowboy-options.h @@ -25,8 +25,6 @@ namespace snowboy { std::string* m_string_value; }; type m_type; - // TODO: This might be an artifact of std::map, since its unused in OptionInfo - char data[28]; OptionInfo(bool* ptr); OptionInfo(std::string* ptr); @@ -39,7 +37,7 @@ namespace snowboy { void SetValue(const std::string& v); }; - static_assert(sizeof(OptionInfo) == 0x38); + struct OptionsItf { virtual void Register(const std::string& prefix, const std::string& name, const std::string& usage_info, bool* ptr) = 0; virtual void Register(const std::string& prefix, const std::string& name, const std::string& usage_info, int32_t* ptr) = 0; @@ -56,7 +54,6 @@ namespace snowboy { std::string m_usage; std::vector m_arguments; std::unordered_map m_options; - std::unordered_map field_0x70; ParseOptions(const std::string& usage); ~ParseOptions(); @@ -78,15 +75,4 @@ namespace snowboy { void ReadConfigFile(const std::string& filename); void ReadConfigString(const std::string& config); }; -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Winvalid-offsetof" - static_assert(offsetof(ParseOptions, m_opt_print_usage) == 0x8); - static_assert(offsetof(ParseOptions, m_opt_config_file) == 0x10); - static_assert(offsetof(ParseOptions, m_usage) == 0x18); - static_assert(offsetof(ParseOptions, m_arguments) == 0x20); - static_assert(offsetof(ParseOptions, m_options) == 0x38); - static_assert(offsetof(ParseOptions, field_0x70) == 0x70); -#pragma GCC diagnostic pop - // TODO: This might be wrong, since we dont have a new/malloc call for it. - static_assert(sizeof(ParseOptions) == 0xa8); } // namespace snowboy \ No newline at end of file diff --git a/lib/template-container.h b/lib/template-container.h index 2f0b89d..d68d96f 100644 --- a/lib/template-container.h +++ b/lib/template-container.h @@ -19,5 +19,4 @@ namespace snowboy { void Clear(); void AddTemplate(const MatrixBase& tpl); }; - static_assert(sizeof(TemplateContainer) == 0x28); } // namespace snowboy \ No newline at end of file diff --git a/lib/template-detect-stream.h b/lib/template-detect-stream.h index a7de113..2616a4b 100644 --- a/lib/template-detect-stream.h +++ b/lib/template-detect-stream.h @@ -17,7 +17,6 @@ namespace snowboy { SlidingDtwOptions dtw_options; void Register(const std::string&, OptionsItf*); }; - static_assert(sizeof(TemplateDetectStreamOptions) == 0x28); struct TemplateDetectStream : StreamItf { TemplateDetectStreamOptions m_options; std::vector field_x40; @@ -38,5 +37,4 @@ namespace snowboy { size_t NumHotwords(int model_id) const; void UpdateModel() const; }; - static_assert(sizeof(TemplateDetectStream) == 0x98); } // namespace snowboy \ No newline at end of file diff --git a/lib/template-enroll-stream.h b/lib/template-enroll-stream.h index 5726816..1474d24 100644 --- a/lib/template-enroll-stream.h +++ b/lib/template-enroll-stream.h @@ -15,7 +15,6 @@ namespace snowboy { std::string model_filename; void Register(const std::string&, OptionsItf*); }; - static_assert(sizeof(TemplateEnrollStreamOptions) == 0x20); struct TemplateEnrollStream : StreamItf { TemplateEnrollStreamOptions m_options; TemplateContainer field_x38; @@ -30,5 +29,4 @@ namespace snowboy { void SetModelFilename(const std::string& name); }; - static_assert(sizeof(TemplateEnrollStream) == 0x80); } // namespace snowboy \ No newline at end of file diff --git a/lib/universal-detect-stream.cpp b/lib/universal-detect-stream.cpp index 981c3d7..e67259c 100644 --- a/lib/universal-detect-stream.cpp +++ b/lib/universal-detect-stream.cpp @@ -43,8 +43,8 @@ namespace snowboy { if (!m_options.sensitivity_str.empty()) SetHighSensitivity(m_options.sensitivity_str); } else SetHighSensitivity(m_options.high_sensitivity_str); - for (size_t i = 0; i < field_x70.size(); i++) - CheckLicense(i); + for (auto& e : m_model_info) + e.CheckLicense(); field_x60 = false; field_x64 = 0; field_x68 = false; @@ -58,14 +58,14 @@ namespace snowboy { std::vector read_info; auto read_res = m_connectedStream->Read(&read_mat, &read_info); if ((read_res & 0xc2) != 0) return read_res; - for (size_t file = 0; file < field_x70.size(); file++) { + for (size_t file = 0; file < m_model_info.size(); file++) { Matrix nnet_out_mat; std::vector nnet_out_info; if ((read_res & 0x18) == 0) - field_x70[file].Compute(read_mat, read_info, &nnet_out_mat, &nnet_out_info); + m_model_info[file].network.Compute(read_mat, read_info, &nnet_out_mat, &nnet_out_info); else - field_x70[file].FlushOutput(read_mat, read_info, &nnet_out_mat, &nnet_out_info); - SmoothPosterior(file, &nnet_out_mat); + m_model_info[file].network.FlushOutput(read_mat, read_info, &nnet_out_mat, &nnet_out_info); + m_model_info[file].SmoothPosterior(&nnet_out_mat); for (size_t r = 0; r < nnet_out_mat.m_rows; r += m_options.slide_step) { auto max = 0; if (r + m_options.slide_step > nnet_out_mat.m_rows) @@ -76,14 +76,14 @@ namespace snowboy { const auto max_frame_id = nnet_out_info[max - 1].frame_id; float fVar8 = 0.0f; int local_130 = -1; - for (size_t i = 0; i < field_x88[file].size(); i++) { + for (size_t i = 0; i < m_model_info[file].keywords.size(); i++) { auto posterior = GetHotwordPosterior(file, i, max_frame_id); if (!field_x68 || max_frame_id - field_x6c < 0x33) { if (field_x60) { if (3000 < max_frame_id - field_x64) { field_x60 = false; } - if (1.0f - field_xb8[file][i] <= posterior && m_options.min_detection_interval < max_frame_id - field_x58) + if (1.0f - m_model_info[file].keywords[i].high_sensitivity <= posterior && m_options.min_detection_interval < max_frame_id - field_x58) { if (fVar8 < posterior) { local_130 = i; @@ -92,10 +92,10 @@ namespace snowboy { field_x64 = max_frame_id; } } else { - if (posterior < 1.0f - field_xa0[file][i] || max_frame_id - field_x58 <= m_options.min_detection_interval) { + if (posterior < 1.0f - m_model_info[file].keywords[i].sensitivity || max_frame_id - field_x58 <= m_options.min_detection_interval) { if (!field_x68 - && 1.0f - field_xb8[file][i] <= posterior - && posterior < 1.0f - field_xa0[file][i] + && 1.0f - m_model_info[file].keywords[i].high_sensitivity <= posterior + && posterior < 1.0f - m_model_info[file].keywords[i].sensitivity && max_frame_id - field_x58 <= m_options.min_detection_interval) { field_x68 = true; field_x6c = max_frame_id; @@ -105,7 +105,7 @@ namespace snowboy { local_130 = i; fVar8 = posterior; } - if (!field_x68 && field_xa0[file][i] < field_xb8[file][i]) { + if (!field_x68 && m_model_info[file].keywords[i].sensitivity < m_model_info[file].keywords[i].high_sensitivity) { field_x68 = true; field_x6c = max_frame_id; } @@ -115,7 +115,7 @@ namespace snowboy { field_x68 = false; field_x60 = true; field_x64 = max_frame_id; - if (1.0f - field_xb8[file][i] <= posterior && m_options.min_detection_interval < max_frame_id - field_x58) + if (1.0f - m_model_info[file].keywords[i].high_sensitivity <= posterior && m_options.min_detection_interval < max_frame_id - field_x58) { if (fVar8 < posterior) { local_130 = i; @@ -126,12 +126,12 @@ namespace snowboy { } } if (local_130 != -1) { - CheckLicense(file); + m_model_info[file].CheckLicense(); field_x58 = max_frame_id; field_x5c = max_frame_id; ResetDetection(); mat->Resize(1, 1); - mat->m_data[0] = field_xd0[file][local_130]; + mat->m_data[0] = m_model_info[file].keywords[local_130].hotword_id; if (info != nullptr) { auto i = nnet_out_info[r]; info->push_back(i); @@ -147,8 +147,8 @@ namespace snowboy { } bool UniversalDetectStream::Reset() { - for (auto& e : field_x70) - e.ResetComputation(); + for (auto& e : m_model_info) + e.network.ResetComputation(); ResetDetection(); return true; } @@ -159,12 +159,12 @@ namespace snowboy { UniversalDetectStream::~UniversalDetectStream() {} - void UniversalDetectStream::CheckLicense(int param_1) const { - if (field_x1a8[param_1] > 0.0f) { + void UniversalDetectStream::ModelInfo::CheckLicense() const { + if (license_days > 0.0f) { time_t t; time(&t); - auto diff = difftime(t, field_x190[param_1]); - auto expires = field_x1a8[param_1]; + auto diff = difftime(t, license_start); + auto expires = license_days; if (expires < (diff / 86400.0f)) { SNOWBOY_ERROR() << "Your license for Snowboy has been expired. Please contact KITT.AI at snowboy@kitt.ai"; return; @@ -173,8 +173,8 @@ namespace snowboy { } float UniversalDetectStream::GetHotwordPosterior(int param_1, int param_2, int param_3) { - switch (field_xe8[param_1][param_2]) { - case 1: return HotwordNaiveSearch(param_1, param_2); + switch (m_model_info[param_1].keywords[param_2].search_method) { + case 1: return m_model_info[param_1].HotwordNaiveSearch(param_2); case 2: return HotwordDtwSearch(param_1, param_2); case 3: return HotwordViterbiSearch(param_1, param_2); case 4: return HotwordPiecewiseSearch(param_1, param_2); @@ -191,12 +191,12 @@ namespace snowboy { std::string UniversalDetectStream::GetSensitivity() const { std::stringstream res; - for (size_t i = 0; i < field_xa0.size(); i++) { - for (size_t x = 0; x < field_xa0[i].size(); x++) { + for (size_t i = 0; i < m_model_info.size(); i++) { + for (size_t x = 0; x < m_model_info[i].keywords.size(); x++) { if (i != 0 || x != 0) { res << ", "; } - res << field_xa0[i][x]; + res << m_model_info[i].keywords[x].sensitivity; } } return res.str(); @@ -208,14 +208,14 @@ namespace snowboy { // TODO:: This is unused in all models I have, but we should still implement it at some point } - float UniversalDetectStream::HotwordNaiveSearch(int param_1, int param_2) const { + float UniversalDetectStream::ModelInfo::HotwordNaiveSearch(int param_2) const { float sum = 0.0f; - for (size_t i = 0; i < field_x88[param_1][param_2].size(); i++) { - auto& x = field_x250[param_1][field_x88[param_1][param_2][i]]; - if (field_x160[param_1][param_2][i] > x.front()) return 0.0f; + for (size_t i = 0; i < keywords[param_2].field_x88.size(); i++) { + auto& x = field_x250[keywords[param_2].field_x88[i]]; + if (keywords[param_2].search_floor[i] > x.front()) return 0.0f; sum += logf(std::max(x.front(), std::numeric_limits::min())); } - return expf(sum / static_cast(field_x88[param_1][param_2].size())); + return expf(sum / static_cast(keywords[param_2].field_x88.size())); } float UniversalDetectStream::HotwordPiecewiseSearch(int, int) const { @@ -225,24 +225,24 @@ namespace snowboy { } float UniversalDetectStream::HotwordViterbiSearch(int param_1, int param_2) const { - return HotwordNaiveSearch(param_1, param_2); + return m_model_info[param_1].HotwordNaiveSearch(param_2); // TODO: Implement Viterbi search std::vector x; - x.resize(field_x88[param_1][param_2].size(), -std::numeric_limits::max()); + x.resize(m_model_info[param_1].keywords[param_2].field_x88.size(), -std::numeric_limits::max()); x[0] = 0.0f; std::vector x2; - x2.resize(field_x88[param_1][param_2].size(), 0); - auto& f250 = field_x250[param_1][0]; - int i = f250.size() - field_x148[param_1][param_2].back(); + x2.resize(m_model_info[param_1].keywords[param_2].field_x88.size(), 0); + auto& f250 = m_model_info[param_1].field_x250[0]; + int i = f250.size() - m_model_info[param_1].keywords[param_2].search_mask.back(); do { if (f250.size() <= i) { - auto fVar2 = field_x160[param_1][param_2].back(); + auto fVar2 = m_model_info[param_1].keywords[param_2].search_floor.back(); if (fVar2 <= x2.back()) { - if (field_x178[param_1][param_2] && !x.empty()) { + if (m_model_info[param_1].keywords[param_2].search_max && !x.empty()) { for (auto& e : x) { fVar2 = std::max(e, fVar2); } - return fVar2 / static_cast(field_x148[param_1][param_2].back()); + return fVar2 / static_cast(m_model_info[param_1].keywords[param_2].search_mask.back()); } } } @@ -279,14 +279,18 @@ namespace snowboy { float UniversalDetectStream::HotwordViterbiSearchTracebackLog(int param_1, int param_2) const { // TODO: Implement this - return HotwordNaiveSearch(param_1, param_2); + return m_model_info[param_1].HotwordNaiveSearch(param_2); + } + + int UniversalDetectStream::ModelInfo::NumHotwords() const { + return keywords.size(); } int UniversalDetectStream::NumHotwords(int model_id) const { - if (model_id < field_x88.size() && model_id >= 0) { - return field_x88[model_id].size(); + if (model_id < m_model_info.size() && model_id >= 0) { + return m_model_info[model_id].NumHotwords(); } else { - SNOWBOY_ERROR() << "model_id runs out of range, expecting a value between [0," << field_x88.size() << "], got " << model_id << " instead."; + SNOWBOY_ERROR() << "model_id runs out of range, expecting a value between [0," << m_model_info.size() << "], got " << model_id << " instead."; return 0; } } @@ -295,14 +299,100 @@ namespace snowboy { // TODO: Optimize this by calculating offsets and doing a memcpy for (size_t r = 0; r < param_2.m_rows; r++) { for (size_t c = 0; c < param_2.m_cols; c++) { - field_x250[param_1][c].push_back(param_2.m_data[r * param_2.m_stride + c]); - if (field_x250[param_1][c].size() > field_x220.size()) { - field_x250[param_1][c].pop_front(); + m_model_info[param_1].field_x250[c].push_back(param_2.m_data[r * param_2.m_stride + c]); + if (m_model_info[param_1].field_x250[c].size() > m_model_info.size()) { + m_model_info[param_1].field_x250[c].pop_front(); } } } } + void UniversalDetectStream::KeyWordInfo::ReadKeyword(bool binary, std::istream* is, int slide_window) { + ExpectToken(binary, "", is); + ReadIntegerVector(binary, &field_x88, is); + ExpectToken(binary, "", is); + ReadBasicType(binary, &sensitivity, is); + high_sensitivity = 0.0f; + search_floor.resize(field_x88.size()); + // TODO: I think there is a bug here. + // If SearchMax is present, but SearchMethod is not, the code would branch into the first + // if and throw on the ExpectToken. It might be possible that SearchMethod is *required* + // if SearchMax is present, but I dont know for sure. + if (PeekToken(binary, is) == 'S') { + ExpectToken(binary, "", is); + ReadBasicType(binary, &search_method, is); + ExpectToken(binary, "", is); + ReadBasicType(binary, &search_neighbour, is); + ExpectToken(binary, "", is); + ReadIntegerVector(binary, &search_mask, is); + ExpectToken(binary, "", is); + Vector tvec; + tvec.Read(binary, is); + search_floor.resize(tvec.m_size); + for (size_t i = 0; i < tvec.m_size; i++) + search_floor[i] = tvec.m_data[i]; + } else { + search_mask.resize(field_x88.size()); + for (size_t i = 0; i < search_mask.size(); i++) { + search_mask[i] = (static_cast(i) / static_cast(search_mask.size())) * static_cast(slide_window); + } + } + if (PeekToken(binary, is) == 'S') { + ExpectToken(binary, "", is); + ReadBasicType(binary, &search_max, is); + } + if (PeekToken(binary, is) == 'N') { + ExpectToken(binary, "", is); + ReadBasicType(binary, &field_x1d8, is); + } + if (PeekToken(binary, is) == 'D') { + ExpectToken(binary, "", is); + ReadBasicType(binary, &duration_pass, is); + ExpectToken(binary, "", is); + ReadBasicType(binary, &floor_pass, is); + } + } + + void UniversalDetectStream::ModelInfo::ReadHotwordModel(bool binary, std::istream* is, int num_repeats, int* hotword_id) { + ExpectToken(binary, "", is); + if (PeekToken(binary, is) == 'L') { + ExpectToken(binary, "", is); + ReadBasicType(binary, &license_start, is); + ExpectToken(binary, "", is); + ReadBasicType(binary, &license_days, is); + } else { + license_start = 0; + license_days = 0.0f; + } + ExpectToken(binary, "", is); + ExpectToken(binary, "", is); + ReadBasicType(binary, &smooth_window, is); + ExpectToken(binary, "", is); + ReadBasicType(binary, &slide_window, is); + ExpectToken(binary, "", is); + int num_kws; + ReadBasicType(binary, &num_kws, is); + keywords.resize(num_kws); + for (auto& e : keywords) { + e.search_method = 1; + e.field_x1c0 = num_repeats; + e.field_x1d8 = 1; + } + for (size_t kw = 0; kw < num_kws; kw++) { + keywords[kw].ReadKeyword(binary, is, slide_window); + keywords[kw].hotword_id = (*hotword_id)++; + } + ExpectToken(binary, "", is); + network.Read(binary, is); + field_x238.resize(field_x238.size() + network.OutputDim()); + field_x250.resize(field_x250.size() + network.OutputDim()); + field_x268.resize(field_x268.size() + network.OutputDim()); + if (keywords[0].search_method == 4) { + SNOWBOY_ERROR() << "Not implemented!"; + // TODO + } + } + void UniversalDetectStream::ReadHotwordModel(const std::string& filename) { std::vector files; SplitStringToVector(filename, global_snowboy_string_delimiter, &files); @@ -311,150 +401,41 @@ namespace snowboy { return; } auto s = files.size(); - field_x88.resize(s); - field_x70.resize(s); - field_x190.resize(s); - field_x1a8.resize(s); - field_xa0.resize(s); - field_xb8.resize(s); - field_xd0.resize(s); - field_xe8.resize(s); - field_x100.resize(s); - field_x118.resize(s); - field_x130.resize(s); - field_x148.resize(s); - field_x160.resize(s); - field_x178.resize(s); - field_x208.resize(s); - field_x220.resize(s); - field_x238.resize(s); - field_x250.resize(s); - field_x268.resize(s); - field_x1c0.resize(s); - field_x1d8.resize(s); - field_x1f0.resize(s); - field_x280.resize(s); - field_x298.resize(s); - field_x2b0.resize(s); + m_model_info.resize(s); int hotword_id = 1; for (size_t f = 0; f < files.size(); f++) { Input in{files[f]}; auto binary = in.is_binary(); auto is = in.Stream(); - ExpectToken(binary, "", is); - if (PeekToken(binary, is) == 'L') { - ExpectToken(binary, "", is); - ReadBasicType(binary, &field_x190[f], is); - ExpectToken(binary, "", is); - ReadBasicType(binary, &field_x1a8[f], is); - } else { - field_x190[f] = 0; - field_x1a8[f] = 0.0f; - } - ExpectToken(binary, "", is); - ExpectToken(binary, "", is); - ReadBasicType(binary, &field_x208[f], is); - ExpectToken(binary, "", is); - ReadBasicType(binary, &field_x220[f], is); - ExpectToken(binary, "", is); - int num_kws; - ReadBasicType(binary, &num_kws, is); - field_x88[f].resize(num_kws); - field_xa0[f].resize(num_kws); - field_xb8[f].resize(num_kws); - field_xd0[f].resize(num_kws); - field_xe8[f].resize(num_kws, 1); - field_x100[f].resize(num_kws); - field_x118[f].resize(num_kws); - field_x130[f].resize(num_kws); - field_x148[f].resize(num_kws); - field_x160[f].resize(num_kws); - field_x178[f].resize(num_kws); - field_x1c0[f].resize(num_kws, m_options.num_repeats); - field_x1d8[f].resize(num_kws, 1); - field_x280[f].resize(num_kws); - field_x298[f].resize(num_kws); - for (size_t kw = 0; kw < num_kws; kw++) { - ExpectToken(binary, "", is); - ReadIntegerVector(binary, &field_x88[f][kw], is); - ExpectToken(binary, "", is); - ReadBasicType(binary, &field_xa0[f][kw], is); - field_xb8[f][kw] = 0.0f; - field_x160[f][kw].resize(field_x88[f][kw].size()); - // TODO: I think there is a bug here. - // If SearchMax is present, but SearchMethod is not, the code would branch into the first - // if and throw on the ExpectToken. It might be possible that SearchMethod is *required* - // if SearchMax is present, but I dont know for sure. - if (PeekToken(binary, is) == 'S') { - ExpectToken(binary, "", is); - ReadBasicType(binary, &field_xe8[f][kw], is); - ExpectToken(binary, "", is); - ReadBasicType(binary, &field_x100[f][kw], is); - ExpectToken(binary, "", is); - ReadIntegerVector(binary, &field_x148[f][kw], is); - ExpectToken(binary, "", is); - Vector tvec; - tvec.Read(binary, is); - field_x160[f][kw].resize(tvec.m_size); - for (size_t i = 0; i < tvec.m_size; i++) - field_x160[f][kw][i] = tvec.m_data[i]; - } else { - field_x148[f][kw].resize(field_x88[f][kw].size()); - for (size_t i = 0; i < field_x148[f][kw].size(); i++) { - field_x148[f][kw][i] = (static_cast(i) / static_cast(field_x148[f][kw].size())) * static_cast(field_x220[f]); - } - } - if (PeekToken(binary, is) == 'S') { - ExpectToken(binary, "", is); - bool tbool; - ReadBasicType(binary, &tbool, is); - field_x178[f][kw] = tbool; - } - if (PeekToken(binary, is) == 'N') { - ExpectToken(binary, "", is); - ReadBasicType(binary, &field_x1d8[f][kw], is); - } - if (PeekToken(binary, is) == 'D') { - ExpectToken(binary, "", is); - ReadBasicType(binary, &field_x118[f][kw], is); - ExpectToken(binary, "", is); - ReadBasicType(binary, &field_x130[f][kw], is); - } - field_xd0[f][kw] = hotword_id++; - } - ExpectToken(binary, "", is); - field_x70[f].Read(binary, is); - field_x238[f].resize(field_x238[f].size() + field_x70[f].OutputDim()); - field_x250[f].resize(field_x250[f].size() + field_x70[f].OutputDim()); - field_x268[f].resize(field_x268[f].size() + field_x70[f].OutputDim()); - if (field_xe8[f][0] == 4) { - SNOWBOY_ERROR() << "Not implemented!"; - // TODO - } + m_model_info[f].ReadHotwordModel(binary, is, m_options.num_repeats, &hotword_id); + } + } + + void UniversalDetectStream::ModelInfo::ResetDetection() { + for (size_t x = 0; x < field_x238.size(); x++) { + field_x238[x].clear(); + } + for (size_t x = 0; x < field_x250.size(); x++) { + field_x250[x].clear(); + } + for (size_t x = 0; x < field_x268.size(); x++) { + field_x268[x] = 0.0f; + } + for (size_t x = 0; x < keywords.size(); x++) { + keywords[x].field_x280 = false; + } + for (size_t x = 0; x < field_x2b0.size(); x++) { + field_x2b0[x] = 0.0f; + } + for (size_t x = 0; x < keywords.size(); x++) { + keywords[x].field_x298 = -1000; } } void UniversalDetectStream::ResetDetection() { - for (size_t i = 0; i < field_x70.size(); i++) { - for (size_t x = 0; x < field_x238[i].size(); x++) { - field_x238[i][x].clear(); - } - for (size_t x = 0; x < field_x250[i].size(); x++) { - field_x250[i][x].clear(); - } - for (size_t x = 0; x < field_x268[i].size(); x++) { - field_x268[i][x] = 0.0f; - } - for (size_t x = 0; x < field_x280[i].size(); x++) { - field_x280[i][x] = false; - } - for (size_t x = 0; x < field_x2b0[i].size(); x++) { - field_x2b0[i][x] = 0.0f; - } - for (size_t x = 0; x < field_x298[i].size(); x++) { - field_x298[i][x] = -1000; - } + for (auto& e : m_model_info) { + e.ResetDetection(); } } @@ -462,20 +443,20 @@ namespace snowboy { std::vector parts; SplitStringToFloats(param_1, global_snowboy_string_delimiter, &parts); if (parts.size() == 1) { - for (auto& e : field_xb8) { - for (auto& e2 : e) { - e2 = parts[0]; + for (auto& e : m_model_info) { + for (auto& e2 : e.keywords) { + e2.high_sensitivity = parts[0]; } } - } else if (parts.size() == field_xb8.size()) { - for (size_t i = 0; i < field_xb8.size(); i++) { - for (auto& e : field_xb8[i]) { - e = parts[i]; + } else if (parts.size() == m_model_info.size()) { + for (size_t i = 0; i < m_model_info.size(); i++) { + for (auto& e : m_model_info[i].keywords) { + e.high_sensitivity = parts[i]; } } } else { SNOWBOY_ERROR() << "Number of sensitivities does not match number of hotwords (" - << parts.size() << " v.s. " << field_xb8.size() + << parts.size() << " v.s. " << m_model_info.size() << "). Note that each universal model may have multiple hotwords."; return; } @@ -485,104 +466,124 @@ namespace snowboy { std::vector parts; SplitStringToFloats(param_1, global_snowboy_string_delimiter, &parts); if (parts.size() == 1) { - for (auto& e : field_xa0) { - for (auto& e2 : e) { - e2 = parts[0]; + for (auto& e : m_model_info) { + for (auto& e2 : e.keywords) { + e2.sensitivity = parts[0]; } } - } else if (parts.size() == field_xa0.size()) { - for (size_t i = 0; i < field_xa0.size(); i++) { - for (auto& e : field_xa0[i]) { - e = parts[i]; + } else if (parts.size() == m_model_info.size()) { + for (size_t i = 0; i < m_model_info.size(); i++) { + for (auto& e : m_model_info[i].keywords) { + e.sensitivity = parts[i]; } } } else { SNOWBOY_ERROR() << "Number of sensitivities does not match number of hotwords (" - << parts.size() << " v.s. " << field_xa0.size() + << parts.size() << " v.s. " << m_model_info.size() << "). Note that each universal model may have multiple hotwords."; return; } } void UniversalDetectStream::SetSlideWindowSize(const std::string& param_1) { - SplitStringToIntegers(param_1, global_snowboy_string_delimiter, &field_x220); + std::vector parts; + SplitStringToIntegers(param_1, global_snowboy_string_delimiter, &parts); + for (size_t i = 0; i < std::min(m_model_info.size(), parts.size()); i++) { + m_model_info[i].slide_window = parts[i]; + } } void UniversalDetectStream::SetSmoothWindowSize(const std::string& param_1) { - SplitStringToIntegers(param_1, global_snowboy_string_delimiter, &field_x208); + std::vector parts; + SplitStringToIntegers(param_1, global_snowboy_string_delimiter, &parts); + for (size_t i = 0; i < std::min(m_model_info.size(), parts.size()); i++) { + m_model_info[i].smooth_window = parts[i]; + } } - void UniversalDetectStream::SmoothPosterior(int param_1, Matrix* param_2) { + void UniversalDetectStream::ModelInfo::SmoothPosterior(Matrix* param_2) { for (size_t r = 0; r < param_2->m_rows; r++) { for (size_t c = 0; c < param_2->m_cols; c++) { auto val = param_2->m_data[r * param_2->m_stride + c]; - field_x268[param_1][c] += val; - field_x238[param_1][c].push_back(val); - if (field_x238[param_1][c].size() > field_x208[param_1]) { - field_x238[param_1][c].pop_front(); + field_x268[c] += val; + field_x238[c].push_back(val); + if (field_x238[c].size() > smooth_window) { + field_x238[c].pop_front(); } - param_2->m_data[r * param_2->m_stride + c] = field_x268[param_1][c] / field_x208[param_1]; + param_2->m_data[r * param_2->m_stride + c] = field_x268[c] / smooth_window; } } } + void UniversalDetectStream::ModelInfo::UpdateLicense(long param_2, float param_3) { + license_start = param_2; + license_days = param_3; + } + void UniversalDetectStream::UpdateLicense(int param_1, long param_2, float param_3) { - field_x190[param_1] = param_2; - field_x1a8[param_1] = param_3; + m_model_info[param_1].UpdateLicense(param_2, param_3); } void UniversalDetectStream::UpdateModel() const { WriteHotwordModel(true, m_options.model_str); } + void UniversalDetectStream::KeyWordInfo::WriteKeyword(bool binary, std::ostream* os) const { + WriteToken(binary, "", os); + WriteIntegerVector(binary, field_x88, os); + WriteToken(binary, "", os); + WriteBasicType(binary, sensitivity, os); + WriteToken(binary, "", os); + WriteBasicType(binary, search_method, os); + WriteToken(binary, "", os); + WriteBasicType(binary, search_neighbour, os); + WriteToken(binary, "", os); + WriteIntegerVector(binary, search_mask, os); + WriteToken(binary, "", os); + Vector tvec; + tvec.Resize(search_floor.size()); + // TODO: This could be a memcpy, or even better implement writing for vector + for (size_t i = 0; i < tvec.m_size; i++) { + tvec.m_data[i] = search_floor[i]; + } + tvec.Write(binary, os); + WriteToken(binary, "", os); + WriteBasicType(binary, search_max, os); + WriteToken(binary, "", os); + WriteBasicType(binary, field_x1d8, os); + WriteToken(binary, "", os); + WriteBasicType(binary, duration_pass, os); + WriteToken(binary, "", os); + WriteBasicType(binary, floor_pass, os); + } + + void UniversalDetectStream::ModelInfo::WriteHotwordModel(bool binary, std::ostream* os) const { + WriteToken(binary, "", os); + WriteToken(binary, "", os); + WriteBasicType(binary, license_start, os); + WriteToken(binary, "", os); + WriteBasicType(binary, license_days, os); + WriteToken(binary, "", os); + WriteToken(binary, "", os); + WriteBasicType(binary, smooth_window, os); + WriteToken(binary, "", os); + WriteBasicType(binary, slide_window, os); + WriteToken(binary, "", os); + WriteBasicType(binary, keywords.size(), os); + for (size_t kw = 0; kw < keywords.size(); kw++) { + keywords[kw].WriteKeyword(binary, os); + } + WriteToken(binary, "", os); + network.Write(binary, os); + } + void UniversalDetectStream::WriteHotwordModel(bool binary, const std::string& filename) const { std::vector parts; SplitStringToVector(filename, global_snowboy_string_delimiter, &parts); for (size_t file = 0; file < parts.size(); file++) { Output out{parts[file], binary}; auto os = out.Stream(); - WriteToken(binary, "", os); - WriteToken(binary, "", os); - WriteBasicType(binary, field_x190[file], os); - WriteToken(binary, "", os); - WriteBasicType(binary, field_x1a8[file], os); - WriteToken(binary, "", os); - WriteToken(binary, "", os); - WriteBasicType(binary, field_x208[file], os); - WriteToken(binary, "", os); - WriteBasicType(binary, field_x220[file], os); - WriteToken(binary, "", os); - WriteBasicType(binary, field_x88[file].size(), os); - for (size_t kw = 0; kw < field_x88[file].size(); kw++) { - WriteToken(binary, "", os); - WriteIntegerVector(binary, field_x88[file][kw], os); - WriteToken(binary, "", os); - WriteBasicType(binary, field_xa0[file][kw], os); - WriteToken(binary, "", os); - WriteBasicType(binary, field_xe8[file][kw], os); - WriteToken(binary, "", os); - WriteBasicType(binary, field_x100[file][kw], os); - WriteToken(binary, "", os); - WriteIntegerVector(binary, field_x148[file][kw], os); - WriteToken(binary, "", os); - Vector tvec; - tvec.Resize(field_x160[file][kw].size()); - // TODO: This could be a memcpy, or even better implement writing for vector - for (size_t i = 0; i < tvec.m_size; i++) { - tvec.m_data[i] = field_x160[file][kw][i]; - } - tvec.Write(binary, os); - WriteToken(binary, "", os); - WriteBasicType(binary, field_x178[file][kw], os); - WriteToken(binary, "", os); - WriteBasicType(binary, field_x1d8[file][kw], os); - WriteToken(binary, "", os); - WriteBasicType(binary, field_x118[file][kw], os); - WriteToken(binary, "", os); - WriteBasicType(binary, field_x130[file][kw], os); - } - WriteToken(binary, "", os); - field_x70[file].Write(binary, os); + m_model_info[file].WriteHotwordModel(binary, os); } } diff --git a/lib/universal-detect-stream.h b/lib/universal-detect-stream.h index d4bc0a5..0aba0a2 100644 --- a/lib/universal-detect-stream.h +++ b/lib/universal-detect-stream.h @@ -2,6 +2,7 @@ #include #include #include +#include #include #include @@ -22,13 +23,11 @@ namespace snowboy { bool debug_mode; void Register(const std::string&, OptionsItf*); }; - static_assert(sizeof(UniversalDetectStreamOptions) == 0x40); struct UniversalDetectStream : StreamItf { struct PieceInfo { char unknown[12]; }; - static_assert(sizeof(PieceInfo) == 0xc); UniversalDetectStreamOptions m_options; int field_x58; @@ -38,52 +37,67 @@ namespace snowboy { bool field_x68; int field_x6c; - // TODO: This really needs a refactor asap once we can... - // Mommy, I am scared... - // Whoever though that 25 vectors (of vectors of vectors...) was - // a good idea instead of just putting them into an object - // or at least a struct should never touch a computer again. - // And for goods sake no C++ code.... - std::vector field_x70; - // Kw <= unsure what this means - std::vector>> field_x88; - // Kw Sensitivity - std::vector> field_xa0; - // Kw High Sensitivity - std::vector> field_xb8; - std::vector> field_xd0; - // Kw Search Method - std::vector> field_xe8; - // Kw Search Neighbour - std::vector> field_x100; - // Kw DurationPass - std::vector> field_x118; - // Kw FloorPass - std::vector> field_x130; - // Kw SearchMask - std::vector>> field_x148; - // Kw SearchFloor - std::vector>> field_x160; - // Kw SearchMax - std::vector> field_x178; - // License start - std::vector field_x190; - // License days - std::vector field_x1a8; - std::vector> field_x1c0; - // Kw NumPieces - std::vector> field_x1d8; - std::vector>>> field_x1f0; - // Smooth window - std::vector field_x208; - // Slide window - std::vector field_x220; - std::vector>> field_x238; - std::vector>> field_x250; - std::vector> field_x268; - std::vector> field_x280; - std::vector> field_x298; - std::vector> field_x2b0; + struct KeyWordInfo { + // Kw <= unsure what this means + std::vector field_x88; + // Kw Sensitivity + float sensitivity; + // Kw High Sensitivity + float high_sensitivity; + int hotword_id; + // Kw Search Method + int search_method; // TODO: This could be an enum + // Kw Search Neighbour + int search_neighbour; + // Kw DurationPass + int duration_pass; + // Kw FloorPass + int floor_pass; + // Kw SearchMask + std::vector search_mask; + // Kw SearchFloor + std::vector search_floor; + // Kw SearchMax + bool search_max; + int field_x1c0; + // Kw NumPieces + int field_x1d8; + bool field_x280; + int field_x298; + + void ReadKeyword(bool binary, std::istream* is, int slide_window); + void WriteKeyword(bool binary, std::ostream* os) const; + }; + + struct ModelInfo { + Nnet network; + std::vector keywords; + // License start + long license_start; + // License days + float license_days; + + std::vector>> field_x1f0; + // Smooth window + int smooth_window; + // Slide window + int slide_window; + std::vector> field_x238; + std::vector> field_x250; + std::vector field_x268; + std::vector field_x2b0; + + void CheckLicense() const; + void SmoothPosterior(Matrix* param_2); + float HotwordNaiveSearch(int) const; + int NumHotwords() const; + void ReadHotwordModel(bool binary, std::istream* is, int num_repeats, int* hotword_id); + void WriteHotwordModel(bool binary, std::ostream* os) const; + void ResetDetection(); + void UpdateLicense(long, float); + }; + + std::vector m_model_info; UniversalDetectStream(const UniversalDetectStreamOptions& options); virtual int Read(Matrix* mat, std::vector* info) override; @@ -91,11 +105,9 @@ namespace snowboy { virtual std::string Name() const override; virtual ~UniversalDetectStream(); - void CheckLicense(int) const; float GetHotwordPosterior(int, int, int); std::string GetSensitivity() const; float HotwordDtwSearch(int, int) const; - float HotwordNaiveSearch(int, int) const; float HotwordPiecewiseSearch(int, int) const; float HotwordViterbiSearch(int, int) const; float HotwordViterbiSearch(int, int, int, const PieceInfo&) const; @@ -111,10 +123,8 @@ namespace snowboy { void SetSensitivity(const std::string&); void SetSlideWindowSize(const std::string&); void SetSmoothWindowSize(const std::string&); - void SmoothPosterior(int, Matrix*); void UpdateLicense(int, long, float); void UpdateModel() const; void WriteHotwordModel(bool binary, const std::string& filename) const; }; - static_assert(sizeof(UniversalDetectStream) == 0x2c8); } // namespace snowboy \ No newline at end of file diff --git a/lib/vad-lib.h b/lib/vad-lib.h index 9cf391d..34fcb23 100644 --- a/lib/vad-lib.h +++ b/lib/vad-lib.h @@ -6,7 +6,6 @@ namespace snowboy { int min_non_voice_frames; int min_voice_frames; }; - static_assert(sizeof(VadStateOptions) == 8); enum VoiceType { VT_0, VT_1, VT_2 }; @@ -23,5 +22,4 @@ namespace snowboy { void Reset(); void GetVoiceStates(const std::vector&, std::vector*); }; - static_assert(sizeof(VadState) == 24); } // namespace snowboy \ No newline at end of file diff --git a/lib/vad-state-stream.h b/lib/vad-state-stream.h index d65003e..5555ae2 100644 --- a/lib/vad-state-stream.h +++ b/lib/vad-state-stream.h @@ -13,7 +13,6 @@ namespace snowboy { int extra_frame_adjust; void Register(const std::string&, OptionsItf*); }; - static_assert(sizeof(VadStateStreamOptions) == 0x10); struct VadStateStream : StreamItf { const VadStateStreamOptions m_options; int field_x28; @@ -37,5 +36,4 @@ namespace snowboy { virtual std::string Name() const override; virtual ~VadStateStream(); }; - static_assert(sizeof(VadStateStream) == 0xa8); } // namespace snowboy \ No newline at end of file diff --git a/lib/vector-wrapper.cpp b/lib/vector-wrapper.cpp index c456b72..65655ff 100644 --- a/lib/vector-wrapper.cpp +++ b/lib/vector-wrapper.cpp @@ -239,26 +239,6 @@ namespace snowboy { return SubVector(*this, param_1, param_2); } - void VectorBase::Read(bool binary, bool add, std::istream* is) { - // TODO: Since reading is always the same, couldn't we just drop it in here ? - Vector tmp; - tmp.Resize(m_size, MatrixResizeType::kSetZero); - tmp.Read(binary, false, is); - if (tmp.m_size != m_size) { - SNOWBOY_ERROR() << "Failed to read Vector: size missmatch (" << tmp.m_size << " v.s. " << m_size << ")."; - return; - } - if (add) { - AddVec(1.0f, tmp); - } else { - CopyFromVec(tmp); - } - } - - void VectorBase::Read(bool binary, std::istream* is) { - Read(binary, false, is); - } - void VectorBase::Scale(float factor) { cblas_sscal(m_size, factor, m_data, 1); } @@ -306,71 +286,61 @@ namespace snowboy { if (size <= m_size) { m_size = size; if (resize == MatrixResizeType::kSetZero) Set(0.0f); - } else { - // The new size is larger than we currently are, so we need to reallocate. + return; + } + + // The new size is larger than we currently are, so we need to reallocate. #if HAS_MALLOC_USABLE_SIZE - auto usable = malloc_usable_size(m_data); - if (usable >= size * sizeof(float)) { - if (resize == MatrixResizeType::kSetZero) { - memset(&m_data[m_size], 0, usable - m_size * sizeof(float)); - } - m_size = size; - return; + auto usable = malloc_usable_size(m_data); + if (usable >= size * sizeof(float)) { + if (resize == MatrixResizeType::kSetZero) { + memset(&m_data[m_size], 0, usable - m_size * sizeof(float)); } + m_size = size; + return; + } #endif - // We dont have usable size or the allocated block was to small - if (resize == MatrixResizeType::kCopyData) { - // Since we would copy it anyway we can just call realloc and maybe save copying (e.g. if the next block is free). - auto ptr = static_cast(realloc(m_data, size * sizeof(float))); - if (ptr == nullptr) throw std::bad_alloc(); - if (ptr != m_data && (reinterpret_cast(ptr) % 16) != 0) { - // realloc moved the data but the new buffer is not aligned correctly - allocs++; - frees++; - free(ptr); - allocs++; - ptr = static_cast(SnowboyMemalign(16, size * sizeof(float))); - if (ptr == nullptr) throw std::bad_alloc(); - memcpy(ptr, m_data, sizeof(float) * m_size); - } - m_data = ptr; - memset(&m_data[m_size], 0, (size - m_size) * sizeof(float)); - } else if (resize == MatrixResizeType::kSetZero) { + // We dont have usable size or the allocated block was to small + if (resize == MatrixResizeType::kCopyData) { + // Since we would copy it anyway we can just call realloc and maybe save copying (e.g. if the next block is free). + auto ptr = static_cast(realloc(m_data, size * sizeof(float))); + if (ptr == nullptr) throw std::bad_alloc(); + if (ptr != m_data && (reinterpret_cast(ptr) % 16) != 0) { + // realloc moved the data but the new buffer is not aligned correctly allocs++; - auto ptr = static_cast(SnowboyMemalign(16, size * sizeof(float))); - if (ptr == nullptr) throw std::bad_alloc(); - if (m_data) { - frees++; - free(m_data); - } - memset(ptr, 0, size * sizeof(float)); - m_data = ptr; - } else { + frees++; + free(ptr); allocs++; - auto ptr = static_cast(SnowboyMemalign(16, size * sizeof(float))); + ptr = static_cast(SnowboyMemalign(16, size * sizeof(float))); if (ptr == nullptr) throw std::bad_alloc(); - if (m_data) { - frees++; - free(m_data); - } - m_data = ptr; + memcpy(ptr, m_data, sizeof(float) * m_size); } - m_size = size; - } - } - - void Vector::AllocateVectorMemory(int size) { - if (size == 0) { - m_data = nullptr; + m_data = ptr; + memset(&m_data[m_size], 0, (size - m_size) * sizeof(float)); + } else if (resize == MatrixResizeType::kSetZero) { + allocs++; + auto ptr = static_cast(SnowboyMemalign(16, size * sizeof(float))); + if (ptr == nullptr) throw std::bad_alloc(); + if (m_data) { + frees++; + free(m_data); + } + memset(ptr, 0, size * sizeof(float)); + m_data = ptr; } else { - m_data = static_cast(SnowboyMemalign(16, size << 2)); - if (m_data == nullptr) throw std::bad_alloc(); allocs++; + auto ptr = static_cast(SnowboyMemalign(16, size * sizeof(float))); + if (ptr == nullptr) throw std::bad_alloc(); + if (m_data) { + frees++; + free(m_data); + } + m_data = ptr; } m_size = size; } - void Vector::ReleaseVectorMemory() { + Vector::~Vector() { if (m_data) { SnowboyMemalignFree(m_data); frees++; @@ -393,6 +363,7 @@ namespace snowboy { void Vector::Read(bool binary, bool add, std::istream* is) { if (!binary) { + // TODO: Is this still accurate ? SNOWBOY_ERROR() << "Not implemented"; ExpectToken(binary, "[", is); uint32_t i = 0; @@ -479,7 +450,7 @@ namespace snowboy { SubVector::SubVector(const VectorBase& parent, int offset, int size) { m_data = parent.m_data + offset; - m_size = size; // TODO: std::min(parent.m_size - offset, size); + m_size = std::min(parent.m_size - offset, size); } SubVector::SubVector(const MatrixBase& parent, int row) { diff --git a/lib/vector-wrapper.h b/lib/vector-wrapper.h index 1b867cf..0a75a0c 100644 --- a/lib/vector-wrapper.h +++ b/lib/vector-wrapper.h @@ -10,6 +10,14 @@ namespace snowboy { uint32_t m_size{0}; float* m_data{nullptr}; + float* begin() const noexcept { return m_data; } + float* end() const noexcept { return m_data + m_size; } + size_t size() const noexcept { return m_size; } + float* data() const noexcept { return m_data; } + float& operator[](size_t index) const noexcept { return m_data[index]; } + float& operator()(size_t index) const noexcept { return m_data[index]; } + bool empty() const noexcept { return size() == 0; } + void Add(float x); void AddDiagMat2(float, const MatrixBase&, MatrixTransposeType, float); void AddMatVec(float, const MatrixBase&, MatrixTransposeType, const VectorBase&, float); @@ -34,8 +42,6 @@ namespace snowboy { float Norm(float) const; SubVector Range(int, int) const; SubVector Range(int, int); - void Read(bool, bool, std::istream*); - void Read(bool, std::istream*); void Scale(float); void Set(float); void SetRandomGaussian(); @@ -57,9 +63,7 @@ namespace snowboy { } void Resize(int size, MatrixResizeType resize = MatrixResizeType::kSetZero); - void AllocateVectorMemory(int size); - void ReleaseVectorMemory(); // NOTE: Called destroy in kaldi - ~Vector() { ReleaseVectorMemory(); } + ~Vector(); Vector& operator=(const Vector& other); Vector& operator=(const VectorBase& other); @@ -69,7 +73,7 @@ namespace snowboy { } void Read(bool, bool, std::istream*); - void Read(bool, std::istream*); // Read(p1, false, p2); + void Read(bool, std::istream*); void Swap(Vector* other); void RemoveElement(int index); @@ -77,11 +81,9 @@ namespace snowboy { static void ResetAllocStats(); }; struct SubVector : VectorBase { + // TODO: Those int should be size_t or at least uint SubVector(const VectorBase& parent, int, int); SubVector(const MatrixBase& parent, int); SubVector(const SubVector& other); }; - static_assert(sizeof(VectorBase) == 0x10); - static_assert(sizeof(Vector) == 0x10); - static_assert(sizeof(SubVector) == 0x10); } // namespace snowboy \ No newline at end of file diff --git a/lib/types.h b/lib/wave-header.h similarity index 100% rename from lib/types.h rename to lib/wave-header.h diff --git a/test/EnrollTest.cpp b/test/EnrollTest.cpp index a1a69df..93df2c9 100644 --- a/test/EnrollTest.cpp +++ b/test/EnrollTest.cpp @@ -27,9 +27,6 @@ TEST(EnrollTest, PersonalEnroll) { auto res = enroll.RunEnrollment(str_data); ASSERT_EQ(res, 0); } - for (auto& e : enroll.enroll_pipeline_->m_templateEnrollStream->field_x38.m_templates) { - std::cout << e.m_rows << "x" << e.m_cols << " hash=" << hash(e) << std::endl; - } ASSERT_TRUE(file_exists("temp_enroll_model.pmdl")); ASSERT_EQ(hash(enroll.enroll_pipeline_->m_templateEnrollStream->field_x38.m_templates.front()), 928553); // TODO: I would really like to do a md5 of the file instead but due to rounding errors (and dithering) thats not an option @@ -55,9 +52,6 @@ TEST(EnrollTest, PersonalEnroll2) { auto res = enroll.RunEnrollment(str_data); ASSERT_EQ(res, 0); } - for (auto& e : enroll.enroll_pipeline_->m_templateEnrollStream->field_x38.m_templates) { - std::cout << e.m_rows << "x" << e.m_cols << " hash=" << hash(e) << std::endl; - } ASSERT_TRUE(file_exists("temp_enroll_model.pmdl")); ASSERT_EQ(hash(enroll.enroll_pipeline_->m_templateEnrollStream->field_x38.m_templates.front()), 928522); // TODO: I would really like to do a md5 of the file instead but due to rounding errors (and dithering) thats not an option diff --git a/test/helper.cpp b/test/helper.cpp index 177b653..d93af98 100644 --- a/test/helper.cpp +++ b/test/helper.cpp @@ -1,5 +1,7 @@ +#include #include #include +#include #include #include #include @@ -138,4 +140,191 @@ std::string md5sum(const std::string& data) { std::string md5sum_file(const std::string& file) { auto content = read_file(file); return md5sum(content); -} \ No newline at end of file +} + +MemoryChecker::snapshot MemoryChecker::g_global{}; + +void MemoryChecker::stacktrace::capture() { + auto res = ::backtrace(trace, 50); + for (int i = res; i < 50; i++) + trace[i] = nullptr; +} + +MemoryChecker::snapshot MemoryChecker::calculate_difference() const noexcept { + snapshot res; + res.num_malloc = g_global.num_malloc - m_start.num_malloc; + res.num_malloc_failed = g_global.num_malloc_failed - m_start.num_malloc_failed; + res.num_free = g_global.num_free - m_start.num_free; + res.num_realloc = g_global.num_realloc - m_start.num_realloc; + res.num_realloc_failed = g_global.num_realloc_failed - m_start.num_realloc_failed; + res.num_realloc_moved = g_global.num_realloc_moved - m_start.num_realloc_moved; + res.num_memalign = g_global.num_memalign - m_start.num_memalign; + res.num_memalign_failed = g_global.num_memalign_failed - m_start.num_memalign_failed; + res.num_chunks_allocated = g_global.num_chunks_allocated - m_start.num_chunks_allocated; + res.num_chunks_allocated_max = g_global.num_chunks_allocated_max - m_start.num_chunks_allocated_max; + res.num_bytes_allocated = g_global.num_bytes_allocated - m_start.num_bytes_allocated; + res.num_bytes_allocated_max = g_global.num_bytes_allocated_max - m_start.num_bytes_allocated_max; + res.bt_max_chunks = g_global.bt_max_chunks; + res.bt_max_bytes = g_global.bt_max_bytes; + return res; +} + +extern "C" void* __libc_malloc(size_t); +extern "C" void* __libc_realloc(void*, size_t); +extern "C" void __libc_free(void*); +extern "C" void* __libc_memalign(size_t alignment, size_t size); + +static thread_local bool in_memchecker = false; +void* MemoryChecker::mc_malloc(size_t size, const void* caller) { + if (in_memchecker) return __libc_malloc(size); + in_memchecker = true; + void* ptr = __libc_malloc(size); + + g_global.num_malloc++; + if (ptr == nullptr) { + g_global.num_malloc_failed++; + in_memchecker = false; + return ptr; + } + + size = malloc_usable_size(ptr); + + g_global.num_bytes_allocated += size; + g_global.num_chunks_allocated++; + if (g_global.num_chunks_allocated > g_global.num_chunks_allocated_max) { + g_global.num_chunks_allocated_max = g_global.num_chunks_allocated; + g_global.bt_max_chunks.capture(); + } + if (g_global.num_bytes_allocated > g_global.num_bytes_allocated_max) { + g_global.num_bytes_allocated_max = g_global.num_bytes_allocated; + g_global.bt_max_bytes.capture(); + } + in_memchecker = false; + return ptr; +} + +void* MemoryChecker::mc_realloc(void* cptr, size_t size, const void* caller) { + if (in_memchecker) return __libc_realloc(cptr, size); + in_memchecker = true; + auto oldsize = malloc_usable_size(cptr); + void* ptr = __libc_realloc(cptr, size); + + g_global.num_realloc++; + if (ptr == nullptr) { + g_global.num_realloc_failed++; + in_memchecker = false; + return ptr; + } + if (cptr == ptr) { + g_global.num_realloc_moved++; + } + + g_global.num_bytes_allocated += (static_cast(size) - static_cast(oldsize)); + if (g_global.num_chunks_allocated > g_global.num_chunks_allocated_max) { + g_global.num_chunks_allocated_max = g_global.num_chunks_allocated; + g_global.bt_max_chunks.capture(); + } + if (g_global.num_bytes_allocated > g_global.num_bytes_allocated_max) { + g_global.num_bytes_allocated_max = g_global.num_bytes_allocated; + g_global.bt_max_bytes.capture(); + } + in_memchecker = false; + return ptr; +} + +void MemoryChecker::mc_free(void* ptr, const void* caller) { + if (in_memchecker) return __libc_free(ptr); + if (ptr == nullptr) return; + in_memchecker = true; + auto oldsize = malloc_usable_size(ptr); + + __libc_free(ptr); + + g_global.num_free++; + g_global.num_bytes_allocated -= oldsize; + g_global.num_chunks_allocated--; + in_memchecker = false; +} + +void* MemoryChecker::mc_memalign(size_t alignment, size_t size, const void* caller) { + void* ptr = __libc_memalign(alignment, size); + if (in_memchecker) return ptr; + in_memchecker = true; + + g_global.num_memalign++; + if (ptr == nullptr) { + g_global.num_memalign_failed++; + in_memchecker = false; + return ptr; + } + + size = malloc_usable_size(ptr); + + g_global.num_bytes_allocated += size; + g_global.num_chunks_allocated++; + if (g_global.num_chunks_allocated > g_global.num_chunks_allocated_max) { + g_global.num_chunks_allocated_max = g_global.num_chunks_allocated; + g_global.bt_max_chunks.capture(); + } + if (g_global.num_bytes_allocated > g_global.num_bytes_allocated_max) { + g_global.num_bytes_allocated_max = g_global.num_bytes_allocated; + g_global.bt_max_bytes.capture(); + } + in_memchecker = false; + return ptr; +} + +std::ostream& operator<<(std::ostream& str, const MemoryChecker::stacktrace& o) { + int n = 0; + for (; n < 49; n++) + if (o.trace[n + 1] == nullptr) break; + auto strings = backtrace_symbols(o.trace, n); + if (strings == NULL) { + str << ""; + return str; + } + for (int j = 0; j < n; j++) + str << strings[j] << "\n"; + free(strings); + return str; +} + +std::ostream& operator<<(std::ostream& str, const MemoryChecker& o) { + auto diff = o.calculate_difference(); + str << "==== Memory report ====\n"; + str << "num_malloc = " << diff.num_malloc << "\n"; + str << "num_malloc_failed = " << diff.num_malloc_failed << "\n"; + str << "num_free = " << diff.num_free << "\n"; + str << "num_realloc = " << diff.num_realloc << "\n"; + str << "num_realloc_failed = " << diff.num_realloc_failed << "\n"; + str << "num_realloc_moved = " << diff.num_realloc_moved << "\n"; + str << "num_memalign = " << diff.num_memalign << "\n"; + str << "num_memalign_failed = " << diff.num_memalign_failed << "\n"; + str << "num_chunks_allocated = " << diff.num_chunks_allocated << "\n"; + str << "num_chunks_allocated_max = " << diff.num_chunks_allocated_max << "\n"; + str << "num_bytes_allocated = " << diff.num_bytes_allocated << "\n"; + str << "num_bytes_allocated_max = " << diff.num_bytes_allocated_max << "\n"; + str << "max_chunks_at:\n" + << diff.bt_max_chunks; + str << "max_bytes_at:\n" + << diff.bt_max_bytes; + return str; +} + +extern "C" void* malloc(size_t size) { + return MemoryChecker::mc_malloc(size, __builtin_return_address(0)); +} +extern "C" void* realloc(void* ptr, size_t size) { + return MemoryChecker::mc_realloc(ptr, size, __builtin_return_address(0)); +} +extern "C" void* memalign(size_t alignment, size_t size) { + return MemoryChecker::mc_memalign(alignment, size, __builtin_return_address(0)); +} +extern "C" void free(void* ptr) { + MemoryChecker::mc_free(ptr, __builtin_return_address(0)); +} +extern "C" int posix_memalign(void** memptr, size_t alignment, size_t size) { + *memptr = memalign(alignment, size); + if (!*memptr) return ENOMEM; + return 0; +} diff --git a/test/helper.h b/test/helper.h index a89b968..13fe124 100644 --- a/test/helper.h +++ b/test/helper.h @@ -49,4 +49,48 @@ std::ostream& operator<<(std::ostream& s, const std::vector& o) { s << " " << e; s << " }"; return s; -} \ No newline at end of file +} + +struct MemoryChecker { + struct stacktrace { + void* trace[50]; + void capture(); + }; + + struct snapshot { + ssize_t num_malloc = 0; + ssize_t num_malloc_failed = 0; + ssize_t num_free = 0; + ssize_t num_realloc = 0; + ssize_t num_realloc_failed = 0; + ssize_t num_realloc_moved = 0; + ssize_t num_memalign = 0; + ssize_t num_memalign_failed = 0; + ssize_t num_chunks_allocated = 0; + ssize_t num_chunks_allocated_max = 0; + ssize_t num_bytes_allocated = 0; + ssize_t num_bytes_allocated_max = 0; + stacktrace bt_max_chunks; + stacktrace bt_max_bytes; + }; + static snapshot g_global; + + snapshot m_start; + + MemoryChecker() { + m_start = g_global; + } + + ~MemoryChecker() { + } + + snapshot calculate_difference() const noexcept; + + static void* mc_malloc(size_t size, const void* caller); + static void* mc_realloc(void* cptr, size_t size, const void* caller); + static void mc_free(void* ptr, const void* caller); + static void* mc_memalign(size_t alignment, size_t size, const void* caller); +}; + +std::ostream& operator<<(std::ostream& str, const MemoryChecker::stacktrace& o); +std::ostream& operator<<(std::ostream& str, const MemoryChecker& o);