Skip to content

Commit

Permalink
Cleanup and more examples (#11)
Browse files Browse the repository at this point in the history
* Proper access modifiers for Vector
* More cleanup and enable lto
* Add examples using real audio
* Add libpulse-dev to dependencies
  • Loading branch information
Thalhammer authored Jun 6, 2021
1 parent f6cbed7 commit c51cd4e
Show file tree
Hide file tree
Showing 43 changed files with 1,047 additions and 335 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/cmake.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ jobs:
steps:
- uses: actions/checkout@v2

- name: Install libatlas-base-dev
run: sudo apt-get install libatlas-base-dev
- name: Install dependencies
run: sudo apt-get install libatlas-base-dev libpulse-dev

- name: Configure CMake
# Configure CMake in a 'build' subdirectory. `CMAKE_BUILD_TYPE` is only required if you are using a single-configuration generator such as make.
Expand Down
9 changes: 8 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,14 @@ endif()

list(APPEND CMAKE_MODULE_PATH ${CMAKE_CURRENT_LIST_DIR}/cmake)

option(SNOWMAN_CXX11_COMPAT "Build library with C++11 strings disabled to be binary compatible with the original release." ON)
# Enable Link-Time Optimization
if(NOT ("${CMAKE_BUILD_TYPE}" STREQUAL "Debug"))
include(CheckIPOSupported)
check_ipo_supported(RESULT LTOAvailable)
endif()
if(LTOAvailable)
message("Link-time optimization enabled")
endif()
option(SNOWMAN_BUILD_APPS "Build helper applications like enroll or cut" ON)
option(SNOWMAN_BUILD_APPS_STATIC "Build apps statically" OFF)
option(SNOWMAN_BUILD_TESTS "Build unit tests (requires gtest and openssl)" ON)
Expand Down
25 changes: 25 additions & 0 deletions apps/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,36 @@ add_executable(enroll
target_include_directories(enroll PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_link_libraries(enroll crypto snowboy)

add_executable(detect-live
helper.cpp
detect-live.cpp
)
target_include_directories(detect-live PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_link_libraries(detect-live pulse-simple pulse snowboy)

add_executable(enroll-live
helper.cpp
enroll-live.cpp
)
target_include_directories(enroll-live PRIVATE ${CMAKE_CURRENT_SOURCE_DIR})
target_link_libraries(enroll-live pulse-simple pulse snowboy)

if(SNOWMAN_BUILD_APPS_STATIC)
IF("${CMAKE_BUILD_TYPE}" MATCHES "^(Debug)\$")
message(WARNING "Ignored BUILD_APPS_STATIC because it is incompatible with debug mode")
else()
message(STATUS "Linking apps statically")
target_link_libraries(cut -static)
target_link_libraries(enroll -static)
#target_link_libraries(detect-live -static)
#target_link_libraries(enroll-live -static)
endif()
endif()

if(LTOAvailable)
message(STATUS "LTO enabled for apps")
set_property(TARGET cut PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE)
set_property(TARGET enroll PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE)
set_property(TARGET detect-live PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE)
set_property(TARGET enroll-live PROPERTY INTERPROCEDURAL_OPTIMIZATION TRUE)
endif()
37 changes: 37 additions & 0 deletions apps/detect-live.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#include <helper.h>
#include <iostream>
#include <snowboy-detect.h>
#include <pulseaudio.h>

const static auto root = detect_project_root();

namespace pa = pulseaudio;

int main(int argc, const char** argv) try
{
std::string model = root + "resources/models/snowboy.umdl";
if(argc > 1) model = argv[1];
pa::simple_record_stream audio_in{"Microphone input"};
pa::simple_playback_stream audio_out{"Ding"};

snowboy::SnowboyDetect detector(root + "resources/common.res", model);
detector.SetSensitivity("0.3");
detector.SetAudioGain(1.0);
detector.ApplyFrontend(true);

auto ding = read_sample_file(root + "resources/dong.wav");
std::vector<short> samples;
while (true) {
audio_in.read(samples);
auto s = detector.RunDetection(samples.data(), samples.size(), false);
//std::cout << "\r \r" << s << std::flush;
if (s > 0) {
std::cout << "a " << s << std::endl;
audio_out.write(ding);
}
}
return 0;
} catch (const std::exception& e) {
std::cerr << "Error: " << e.what() << std::endl;
return -1;
}
64 changes: 64 additions & 0 deletions apps/enroll-live.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
#include <helper.h>
#include <iostream>
#include <snowboy-detect.h>
#include <pulseaudio.h>

const static auto root = detect_project_root();

namespace pa = pulseaudio;

std::vector<short> record_word(pa::simple_record_stream& audio_in) {
snowboy::SnowboyVad vad{root + "resources/common.res"};
std::vector<short> samples;
std::vector<short> result;
audio_in.read(samples);
// Skip leading silence
while (vad.RunVad(samples.data(), samples.size()) == -2) {
std::cout << "S" << std::flush;
audio_in.read(samples);
}
// Keep samples while there is voice
result = samples;
while (vad.RunVad(samples.data(), samples.size()) == 0) {
std::cout << "A" << std::flush;
audio_in.read(samples);
result.insert(result.end(), samples.begin(), samples.end());
}
std::cout << "D\n" << std::flush;
return result;
}

int main(int argc, const char** argv) try
{
std::string output = "model.pmdl", language = "en";
int64_t num_records = 3;
option_parser parser;
parser.option("--output", &output).set_shortname("-o").set_description("Output filename for the model");
parser.option("--language", &language).set_shortname("-l").set_description("Language of the enrolled word");
parser.option("--nrecs", &num_records).set_min(3).set_shortname("-n").set_required(true).set_description("Number of hotword samples to record for the model");
parser.parse(argc, argv);

snowboy::SnowboyPersonalEnroll enroll{root + "resources/pmdl/en/personal_enroll.res", "model.pmdl"};
snowboy::SnowboyTemplateCut cut{root + "resources/pmdl/en/personal_enroll.res"};
pa::simple_record_stream audio_in{"Microphone input", enroll.SampleRate(), enroll.NumChannels()};

for(int64_t i=1; i<=num_records; i++) {
std::cout << "[" << i << "/" << num_records << "] ";
auto sample = record_word(audio_in);
std::cout << "[" << i << "/" << num_records << "] Got sample with length = " << (static_cast<float>(sample.size())/16000.0f) << "s (" << sample.size() << " samples)" << std::endl;
int new_size;
if(cut.CutTemplate(sample.data(), sample.size(), sample.data(), &new_size) != 0)
throw std::runtime_error("Failed to cut template");
sample.resize(new_size);
std::cout << "[" << i << "/" << num_records << "] length after cutting = " << (static_cast<float>(sample.size())/16000.0f) << "s (" << sample.size() << " samples)" << std::endl;
auto res = enroll.RunEnrollment(sample.data(), sample.size());
if (res != 0) {
std::cerr << "Training failed with error " << res << std::endl;
return -1;
}
}
return 0;
} catch (const std::exception& e) {
std::cerr << "Error: " << e.what() << std::endl;
return -1;
}
2 changes: 1 addition & 1 deletion apps/enroll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ bool parse_args(int argc, const char** argv, std::string& output, std::vector<st
bool no_cut = false;
parser.option("--no-cut-recordings", &no_cut).set_shortname("-nc").set_description("Do not cut recordings before running enrollment");
parser.option("--output", &output).set_shortname("-o").set_description("Output filename for the model");
parser.option("--language", &output).set_shortname("-l").set_description("Language of the enrolled word");
parser.option("--language", &lang).set_shortname("-l").set_description("Language of the enrolled word");
parser.option("--recording", &recordings).set_shortname("-r").set_required(true).set_description("Recording to enroll");
bool print_help = false;
parser.option("--help", &print_help).set_shortname("-h").set_description("Print help");
Expand Down
27 changes: 27 additions & 0 deletions apps/helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,26 @@ void string_list_option::parse(arg_iterator& it) {
value_ptr->push_back(it.take());
}

void int_option::parse(arg_iterator& it) {
if (!it.has_more()) {
throw std::runtime_error("missing argument for option " + longname);
}
*value_ptr = std::stoll(it.take());
if(*value_ptr < minimum || *value_ptr > maximum) {
throw std::runtime_error("value exceeds range");
}
}

int_option& int_option::set_min(int64_t m) noexcept {
this->minimum = m;
return *this;
}

int_option& int_option::set_max(int64_t m) noexcept {
this->maximum = m;
return *this;
}

option_parser::~option_parser() {
for (auto& e : options)
delete e;
Expand Down Expand Up @@ -253,6 +273,13 @@ string_list_option& option_parser::option(std::string longname, std::vector<std:
return *opt;
}

int_option& option_parser::option(std::string longname, int64_t* ptr) {
auto opt = new int_option(ptr);
opt->longname = longname;
options.push_back(opt);
return *opt;
}

std::vector<std::string> option_parser::parse(std::vector<std::string> a) {
arg_iterator it{a};
return parse(it);
Expand Down
11 changes: 11 additions & 0 deletions apps/helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,16 @@ struct string_list_option : basic_option<std::vector<std::string>> {
void parse(arg_iterator& it);
};

struct int_option : basic_option<int64_t> {
int64_t minimum = INT64_MIN;
int64_t maximum = INT64_MAX;
using basic_option::basic_option;
void parse(arg_iterator& it);

int_option& set_min(int64_t minimum) noexcept;
int_option& set_max(int64_t maximum) noexcept;
};

struct option_parser {
std::vector<option_base*> options;

Expand All @@ -73,6 +83,7 @@ struct option_parser {
bool_option& option(std::string longname, bool* ptr);
string_option& option(std::string longname, std::string* ptr);
string_list_option& option(std::string longname, std::vector<std::string>* ptr);
int_option& option(std::string longname, int64_t* ptr);
std::vector<std::string> parse(std::vector<std::string> a);
std::vector<std::string> parse(int argc, const char** argv);
std::vector<std::string> parse(const std::string& str);
Expand Down
Loading

0 comments on commit c51cd4e

Please sign in to comment.