Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify the code a bit using more C++ facilities #5

Merged
merged 2 commits into from
Dec 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
306 changes: 145 additions & 161 deletions bmap-writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <iomanip>
#include <string>
Expand Down Expand Up @@ -104,21 +105,24 @@ bmap_t parseBMap(const std::string &filename) {
return bmapData;
}

void computeSHA256(const char *buffer, size_t size, char *output) {
std::string computeSHA256(const std::vector<char>& buffer, size_t size) {
EVP_MD_CTX *mdctx;
unsigned char hash[EVP_MAX_MD_SIZE];
unsigned int hash_len;

mdctx = EVP_MD_CTX_new();
EVP_DigestInit_ex(mdctx, EVP_sha256(), NULL);
EVP_DigestUpdate(mdctx, buffer, size);
EVP_DigestUpdate(mdctx, buffer.data(), size);
EVP_DigestFinal_ex(mdctx, hash, &hash_len);
EVP_MD_CTX_free(mdctx);

std::ostringstream output;
output << std::hex;
for (unsigned int i = 0; i < hash_len; ++i) {
sprintf(output + (i * 2), "%02x", hash[i]);
output << std::setfill('0') << std::setw(2) << static_cast<unsigned int>(hash[i]);
}
output[CHECKSUM_LENGTH] = 0;

return output.str();
}

int getCompressionType(const std::string &imageFile, std::string &compressionType) {
Expand Down Expand Up @@ -169,197 +173,177 @@ void printBufferHex(const char *buffer, size_t size) {
}

int BmapWriteImage(const std::string &imageFile, const bmap_t &bmap, const std::string &device, const std::string &compressionType) {
int dev_fd = open(device.c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
if (dev_fd < 0) {
std::cerr << "Unable to open or create target device" << std::endl;
return 1;
}

gzFile gzImg = nullptr;
lzma_stream lzmaStream = LZMA_STREAM_INIT;
std::vector<char> decBufferIn(DEC_BUFFER_SIZE);
size_t decHead = 0;
std::ifstream imgFile;
int dev_fd = -1;
int ret = 0;

if (compressionType == "gzip") {
gzImg = gzopen(imageFile.c_str(), "rb");
if (!gzImg) {
std::cerr << "Unable to open gzip image file" << std::endl;
close(dev_fd);
return 1;
}
} else if (compressionType == "xz") {
imgFile.open(imageFile, std::ios::binary);
if (!imgFile) {
std::cerr << "Unable to open xz image file" << std::endl;
close(dev_fd);
return 1;
}
lzma_ret ret = lzma_stream_decoder(&lzmaStream, UINT64_MAX, 0);
if (ret != LZMA_OK) {
std::cerr << "Failed to initialize lzma decoder: " << ret << std::endl;
close(dev_fd);
return 1;
try {
dev_fd = open(device.c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
if (dev_fd < 0) {
throw std::string("Unable to open or create target device");
}

lzmaStream.avail_in = 0;
} else if (compressionType == "none") {
imgFile.open(imageFile, std::ios::binary);
if (!imgFile) {
std::cerr << "Unable to open image file" << std::endl;
close(dev_fd);
return 1;
}
} else {
std::cerr << "Unsupported compression type" << std::endl;
close(dev_fd);
return 1;
}
if (compressionType == "gzip") {
gzImg = gzopen(imageFile.c_str(), "rb");
if (!gzImg) {
throw std::string("Unable to open gzip image file");
}
} else if (compressionType == "xz") {
imgFile.open(imageFile, std::ios::binary);
if (!imgFile) {
throw std::string("Unable to open xz image file");
}
lzma_ret ret = lzma_stream_decoder(&lzmaStream, UINT64_MAX, 0);
if (ret != LZMA_OK) {
throw std::string("Failed to initialize lzma decoder: ") + std::to_string(static_cast<unsigned int>(ret));
}

for (const auto &range : bmap.ranges) {
size_t startBlock, endBlock;
if (sscanf(range.range.c_str(), "%zu-%zu", &startBlock, &endBlock) == 1) {
endBlock = startBlock; // Handle single block range
lzmaStream.avail_in = 0;
} else if (compressionType == "none") {
imgFile.open(imageFile, std::ios::binary);
if (!imgFile) {
throw std::string("Unable to open image file");
}
} else {
throw std::string("Unsupported compression type ") + compressionType;
}
std::cout << "Processing Range: startBlock=" << startBlock << ", endBlock=" << endBlock << std::endl;

size_t bufferSize = (endBlock - startBlock + 1) * bmap.blockSize;
std::vector<char> buffer(bufferSize);
size_t outBytes = 0;

if (compressionType == "gzip") {
gzseek(gzImg, static_cast<off_t>(startBlock * bmap.blockSize), SEEK_SET);
int readBytes = gzread(gzImg, buffer.data(), static_cast<unsigned int>(bufferSize));
if (readBytes < 0) {
std::cerr << "Failed to read from gzip image file" << std::endl;
close(dev_fd);
gzclose(gzImg);
return 1;
for (const auto &range : bmap.ranges) {
size_t startBlock, endBlock;
if (sscanf(range.range.c_str(), "%zu-%zu", &startBlock, &endBlock) == 1) {
endBlock = startBlock; // Handle single block range
}
outBytes = static_cast<size_t>(readBytes);
} else if (compressionType == "xz") {
const size_t outStart = startBlock * bmap.blockSize;
const size_t outEnd = ((endBlock + 1) * bmap.blockSize);

// Initialize the output buffer for the decompressor
lzmaStream.next_out = reinterpret_cast<uint8_t*>(buffer.data());
lzmaStream.avail_out = static_cast<size_t>(buffer.size());

while (outBytes < bufferSize) {
size_t chunkSize = 0;

// Whenever no more input data is available, read some from the compressed file
// and reset the input parameters for the decompressor
if (lzmaStream.avail_in == 0) {
imgFile.read(decBufferIn.data(), static_cast<ssize_t>(decBufferIn.size()));
if (imgFile.gcount() == 0 && imgFile.fail()) {
std::cerr << "Failed to read from xz image file" << std::endl;
close(dev_fd);
imgFile.close();
return 1;
} else {
lzmaStream.next_in = reinterpret_cast<const uint8_t*>(decBufferIn.data());
lzmaStream.avail_in = static_cast<size_t>(imgFile.gcount());
}
}
std::cout << "Processing Range: startBlock=" << startBlock << ", endBlock=" << endBlock << std::endl;

// Save the current status of the output buffer...
chunkSize = lzmaStream.avail_out;
size_t bufferSize = (endBlock - startBlock + 1) * bmap.blockSize;
std::vector<char> buffer(bufferSize);
size_t outBytes = 0;

lzma_ret ret = lzma_code(&lzmaStream, LZMA_RUN);
if (ret != LZMA_OK && ret != LZMA_STREAM_END) {
std::cerr << "Failed to decompress xz image file: " << ret << std::endl;
close(dev_fd);
imgFile.close();
return 1;
if (compressionType == "gzip") {
gzseek(gzImg, static_cast<off_t>(startBlock * bmap.blockSize), SEEK_SET);
int readBytes = gzread(gzImg, buffer.data(), static_cast<unsigned int>(bufferSize));
if (readBytes < 0) {
throw std::string("Failed to read from gzip image file");
}
outBytes = static_cast<size_t>(readBytes);
} else if (compressionType == "xz") {
const size_t outStart = startBlock * bmap.blockSize;
const size_t outEnd = ((endBlock + 1) * bmap.blockSize);

// Initialize the output buffer for the decompressor
lzmaStream.next_out = reinterpret_cast<uint8_t*>(buffer.data());
lzmaStream.avail_out = static_cast<size_t>(buffer.size());

while (outBytes < bufferSize) {
size_t chunkSize = 0;

// Whenever no more input data is available, read some from the compressed file
// and reset the input parameters for the decompressor
if (lzmaStream.avail_in == 0) {
imgFile.read(decBufferIn.data(), static_cast<ssize_t>(decBufferIn.size()));
if (imgFile.gcount() == 0 && imgFile.fail()) {
throw std::string("Failed to read from xz image file");
} else {
lzmaStream.next_in = reinterpret_cast<const uint8_t*>(decBufferIn.data());
lzmaStream.avail_in = static_cast<size_t>(imgFile.gcount());
}
}

// ...and then extract the size of the decompressed chunk
chunkSize -= lzmaStream.avail_out;

if (decHead >= outStart && (decHead + chunkSize) <= outEnd) {
// Case 1: all decoded data can be used
outBytes += chunkSize;
} else if (decHead < outStart && (decHead + chunkSize) <= outStart) {
// Case 2: all decoded data shall be discarded
lzmaStream.next_out = reinterpret_cast<uint8_t*>(buffer.data());
lzmaStream.avail_out = static_cast<size_t>(buffer.size());
} else if (decHead < outStart && (decHead + chunkSize) > outStart) {
// Case 3: only the last portion of the decoded data can be used
std::move(buffer.begin() + static_cast<long int>(outStart - decHead),
buffer.begin() + static_cast<long int>(chunkSize),
buffer.begin());
size_t validData = chunkSize - (outStart - decHead);
outBytes += validData;
lzmaStream.next_out = reinterpret_cast<uint8_t*>(buffer.data()) + validData;
lzmaStream.avail_out = buffer.size() - validData;
}
// Save the current status of the output buffer...
chunkSize = lzmaStream.avail_out;

// Advance the head of the decompressed data
decHead += chunkSize;
lzma_ret ret = lzma_code(&lzmaStream, LZMA_RUN);
if (ret != LZMA_OK && ret != LZMA_STREAM_END) {
throw std::string("Failed to decompress xz image file: ") + std::to_string(static_cast<unsigned int>(ret));
}

// ...and then extract the size of the decompressed chunk
chunkSize -= lzmaStream.avail_out;

if (decHead >= outStart && (decHead + chunkSize) <= outEnd) {
// Case 1: all decoded data can be used
outBytes += chunkSize;
} else if (decHead < outStart && (decHead + chunkSize) <= outStart) {
// Case 2: all decoded data shall be discarded
lzmaStream.next_out = reinterpret_cast<uint8_t*>(buffer.data());
lzmaStream.avail_out = static_cast<size_t>(buffer.size());
} else if (decHead < outStart && (decHead + chunkSize) > outStart) {
// Case 3: only the last portion of the decoded data can be used
std::move(buffer.begin() + static_cast<long int>(outStart - decHead),
buffer.begin() + static_cast<long int>(chunkSize),
buffer.begin());
size_t validData = chunkSize - (outStart - decHead);
outBytes += validData;
lzmaStream.next_out = reinterpret_cast<uint8_t*>(buffer.data()) + validData;
lzmaStream.avail_out = buffer.size() - validData;
}

// In case all the required data has been decompressed OR the XZ stream is ended
// OR the input file has been read completely, stop this decompression loop
if ((lzmaStream.avail_out == 0) || (ret == LZMA_STREAM_END) ||
(lzmaStream.avail_in == 0 && imgFile.eof())) {
break;
// Advance the head of the decompressed data
decHead += chunkSize;

// In case all the required data has been decompressed OR the XZ stream is ended
// OR the input file has been read completely, stop this decompression loop
if ((lzmaStream.avail_out == 0) || (ret == LZMA_STREAM_END) ||
(lzmaStream.avail_in == 0 && imgFile.eof())) {
break;
}
}
} else if (compressionType == "none") {
imgFile.seekg(static_cast<std::streamoff>(startBlock * bmap.blockSize), std::ios::beg);
imgFile.read(buffer.data(), static_cast<std::streamsize>(bufferSize));
outBytes = static_cast<size_t>(imgFile.gcount());
if (outBytes == 0 && imgFile.fail()) {
throw std::string("Failed to read from image file");
}
}
} else if (compressionType == "none") {
imgFile.seekg(static_cast<std::streamoff>(startBlock * bmap.blockSize), std::ios::beg);
imgFile.read(buffer.data(), static_cast<std::streamsize>(bufferSize));
outBytes = static_cast<size_t>(imgFile.gcount());
if (outBytes == 0 && imgFile.fail()) {
std::cerr << "Failed to read from image file" << std::endl;
close(dev_fd);
imgFile.close();
return 1;

// Compute and verify the checksum
std::string computedChecksum = computeSHA256(buffer, outBytes);
if (computedChecksum != range.checksum) {
std::stringstream err;
err << "Checksum verification failed for range: " << range.range << std::endl;
err << "Computed Checksum: " << computedChecksum << std::endl;
err << "Expected Checksum: " << range.checksum;
//std::cerr << "Buffer content (hex):" << std::endl;
//printBufferHex(buffer.data(), outBytes);
throw std::string(err.str());
}
}

// Compute and verify the checksum
char computedChecksum[CHECKSUM_LENGTH + 1];
computeSHA256(buffer.data(), outBytes, computedChecksum);
if (strcmp(computedChecksum, range.checksum.c_str()) != 0) {
std::cerr << "Checksum verification failed for range: " << range.range << std::endl;
std::cerr << "Computed Checksum: " << computedChecksum << std::endl;
std::cerr << "Expected Checksum: " << range.checksum << std::endl;
//std::cerr << "Buffer content (hex):" << std::endl;
//printBufferHex(buffer.data(), outBytes);
close(dev_fd);
if (compressionType == "gzip") {
gzclose(gzImg);
} else if (compressionType == "xz" || compressionType == "none") {
imgFile.close();
if (pwrite(dev_fd, buffer.data(), outBytes, static_cast<off_t>(startBlock * bmap.blockSize)) < 0) {
throw std::string("Write to device failed");
}
return 1;
}

if (pwrite(dev_fd, buffer.data(), outBytes, static_cast<off_t>(startBlock * bmap.blockSize)) < 0) {
std::cerr << "Write to device failed"<< std::endl;
close(dev_fd);
if (compressionType == "gzip") {
gzclose(gzImg);
} else if (compressionType == "xz" || compressionType == "none") {
imgFile.close();
}
return 1;
if (fsync(dev_fd) != 0) {
throw std::string("fsync failed after all writes");
}

std::cout << "Finished writing image to device." << std::endl;
}
catch (std::string& err) {
std::cerr << err << std::endl;
ret = -1;
}

if (fsync(dev_fd) != 0) {
std::cerr << "fsync failed after all writes"<< std::endl;
if (dev_fd >= 0) {
close(dev_fd);
}

if (imgFile.is_open()) {
imgFile.close();
}

close(dev_fd);
if (compressionType == "gzip") {
gzclose(gzImg);
} else if (compressionType == "xz" || compressionType == "none") {
imgFile.close();
} else if (compressionType == "xz") {
lzma_end(&lzmaStream);
}
std::cout << "Finished writing image to device." << std::endl;
return 0;

return ret;
}

int main(int argc, char *argv[]) {
Expand Down
Loading