Skip to content

Commit

Permalink
Merge pull request #5 from WallaceIT/err_cleanup
Browse files Browse the repository at this point in the history
Simplify the code a bit using more C++ facilities
  • Loading branch information
embetrix authored Dec 15, 2024
2 parents fd1661b + f33b5b7 commit 1525385
Showing 1 changed file with 145 additions and 161 deletions.
306 changes: 145 additions & 161 deletions bmap-writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <iomanip>
#include <string>
Expand Down Expand Up @@ -104,21 +105,24 @@ bmap_t parseBMap(const std::string &filename) {
return bmapData;
}

void computeSHA256(const char *buffer, size_t size, char *output) {
std::string computeSHA256(const std::vector<char>& buffer, size_t size) {
EVP_MD_CTX *mdctx;
unsigned char hash[EVP_MAX_MD_SIZE];
unsigned int hash_len;

mdctx = EVP_MD_CTX_new();
EVP_DigestInit_ex(mdctx, EVP_sha256(), NULL);
EVP_DigestUpdate(mdctx, buffer, size);
EVP_DigestUpdate(mdctx, buffer.data(), size);
EVP_DigestFinal_ex(mdctx, hash, &hash_len);
EVP_MD_CTX_free(mdctx);

std::ostringstream output;
output << std::hex;
for (unsigned int i = 0; i < hash_len; ++i) {
sprintf(output + (i * 2), "%02x", hash[i]);
output << std::setfill('0') << std::setw(2) << static_cast<unsigned int>(hash[i]);
}
output[CHECKSUM_LENGTH] = 0;

return output.str();
}

int getCompressionType(const std::string &imageFile, std::string &compressionType) {
Expand Down Expand Up @@ -169,197 +173,177 @@ void printBufferHex(const char *buffer, size_t size) {
}

int BmapWriteImage(const std::string &imageFile, const bmap_t &bmap, const std::string &device, const std::string &compressionType) {
int dev_fd = open(device.c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
if (dev_fd < 0) {
std::cerr << "Unable to open or create target device" << std::endl;
return 1;
}

gzFile gzImg = nullptr;
lzma_stream lzmaStream = LZMA_STREAM_INIT;
std::vector<char> decBufferIn(DEC_BUFFER_SIZE);
size_t decHead = 0;
std::ifstream imgFile;
int dev_fd = -1;
int ret = 0;

if (compressionType == "gzip") {
gzImg = gzopen(imageFile.c_str(), "rb");
if (!gzImg) {
std::cerr << "Unable to open gzip image file" << std::endl;
close(dev_fd);
return 1;
}
} else if (compressionType == "xz") {
imgFile.open(imageFile, std::ios::binary);
if (!imgFile) {
std::cerr << "Unable to open xz image file" << std::endl;
close(dev_fd);
return 1;
}
lzma_ret ret = lzma_stream_decoder(&lzmaStream, UINT64_MAX, 0);
if (ret != LZMA_OK) {
std::cerr << "Failed to initialize lzma decoder: " << ret << std::endl;
close(dev_fd);
return 1;
try {
dev_fd = open(device.c_str(), O_WRONLY | O_CREAT, S_IRUSR | S_IWUSR);
if (dev_fd < 0) {
throw std::string("Unable to open or create target device");
}

lzmaStream.avail_in = 0;
} else if (compressionType == "none") {
imgFile.open(imageFile, std::ios::binary);
if (!imgFile) {
std::cerr << "Unable to open image file" << std::endl;
close(dev_fd);
return 1;
}
} else {
std::cerr << "Unsupported compression type" << std::endl;
close(dev_fd);
return 1;
}
if (compressionType == "gzip") {
gzImg = gzopen(imageFile.c_str(), "rb");
if (!gzImg) {
throw std::string("Unable to open gzip image file");
}
} else if (compressionType == "xz") {
imgFile.open(imageFile, std::ios::binary);
if (!imgFile) {
throw std::string("Unable to open xz image file");
}
lzma_ret ret = lzma_stream_decoder(&lzmaStream, UINT64_MAX, 0);
if (ret != LZMA_OK) {
throw std::string("Failed to initialize lzma decoder: ") + std::to_string(static_cast<unsigned int>(ret));
}

for (const auto &range : bmap.ranges) {
size_t startBlock, endBlock;
if (sscanf(range.range.c_str(), "%zu-%zu", &startBlock, &endBlock) == 1) {
endBlock = startBlock; // Handle single block range
lzmaStream.avail_in = 0;
} else if (compressionType == "none") {
imgFile.open(imageFile, std::ios::binary);
if (!imgFile) {
throw std::string("Unable to open image file");
}
} else {
throw std::string("Unsupported compression type ") + compressionType;
}
std::cout << "Processing Range: startBlock=" << startBlock << ", endBlock=" << endBlock << std::endl;

size_t bufferSize = (endBlock - startBlock + 1) * bmap.blockSize;
std::vector<char> buffer(bufferSize);
size_t outBytes = 0;

if (compressionType == "gzip") {
gzseek(gzImg, static_cast<off_t>(startBlock * bmap.blockSize), SEEK_SET);
int readBytes = gzread(gzImg, buffer.data(), static_cast<unsigned int>(bufferSize));
if (readBytes < 0) {
std::cerr << "Failed to read from gzip image file" << std::endl;
close(dev_fd);
gzclose(gzImg);
return 1;
for (const auto &range : bmap.ranges) {
size_t startBlock, endBlock;
if (sscanf(range.range.c_str(), "%zu-%zu", &startBlock, &endBlock) == 1) {
endBlock = startBlock; // Handle single block range
}
outBytes = static_cast<size_t>(readBytes);
} else if (compressionType == "xz") {
const size_t outStart = startBlock * bmap.blockSize;
const size_t outEnd = ((endBlock + 1) * bmap.blockSize);

// Initialize the output buffer for the decompressor
lzmaStream.next_out = reinterpret_cast<uint8_t*>(buffer.data());
lzmaStream.avail_out = static_cast<size_t>(buffer.size());

while (outBytes < bufferSize) {
size_t chunkSize = 0;

// Whenever no more input data is available, read some from the compressed file
// and reset the input parameters for the decompressor
if (lzmaStream.avail_in == 0) {
imgFile.read(decBufferIn.data(), static_cast<ssize_t>(decBufferIn.size()));
if (imgFile.gcount() == 0 && imgFile.fail()) {
std::cerr << "Failed to read from xz image file" << std::endl;
close(dev_fd);
imgFile.close();
return 1;
} else {
lzmaStream.next_in = reinterpret_cast<const uint8_t*>(decBufferIn.data());
lzmaStream.avail_in = static_cast<size_t>(imgFile.gcount());
}
}
std::cout << "Processing Range: startBlock=" << startBlock << ", endBlock=" << endBlock << std::endl;

// Save the current status of the output buffer...
chunkSize = lzmaStream.avail_out;
size_t bufferSize = (endBlock - startBlock + 1) * bmap.blockSize;
std::vector<char> buffer(bufferSize);
size_t outBytes = 0;

lzma_ret ret = lzma_code(&lzmaStream, LZMA_RUN);
if (ret != LZMA_OK && ret != LZMA_STREAM_END) {
std::cerr << "Failed to decompress xz image file: " << ret << std::endl;
close(dev_fd);
imgFile.close();
return 1;
if (compressionType == "gzip") {
gzseek(gzImg, static_cast<off_t>(startBlock * bmap.blockSize), SEEK_SET);
int readBytes = gzread(gzImg, buffer.data(), static_cast<unsigned int>(bufferSize));
if (readBytes < 0) {
throw std::string("Failed to read from gzip image file");
}
outBytes = static_cast<size_t>(readBytes);
} else if (compressionType == "xz") {
const size_t outStart = startBlock * bmap.blockSize;
const size_t outEnd = ((endBlock + 1) * bmap.blockSize);

// Initialize the output buffer for the decompressor
lzmaStream.next_out = reinterpret_cast<uint8_t*>(buffer.data());
lzmaStream.avail_out = static_cast<size_t>(buffer.size());

while (outBytes < bufferSize) {
size_t chunkSize = 0;

// Whenever no more input data is available, read some from the compressed file
// and reset the input parameters for the decompressor
if (lzmaStream.avail_in == 0) {
imgFile.read(decBufferIn.data(), static_cast<ssize_t>(decBufferIn.size()));
if (imgFile.gcount() == 0 && imgFile.fail()) {
throw std::string("Failed to read from xz image file");
} else {
lzmaStream.next_in = reinterpret_cast<const uint8_t*>(decBufferIn.data());
lzmaStream.avail_in = static_cast<size_t>(imgFile.gcount());
}
}

// ...and then extract the size of the decompressed chunk
chunkSize -= lzmaStream.avail_out;

if (decHead >= outStart && (decHead + chunkSize) <= outEnd) {
// Case 1: all decoded data can be used
outBytes += chunkSize;
} else if (decHead < outStart && (decHead + chunkSize) <= outStart) {
// Case 2: all decoded data shall be discarded
lzmaStream.next_out = reinterpret_cast<uint8_t*>(buffer.data());
lzmaStream.avail_out = static_cast<size_t>(buffer.size());
} else if (decHead < outStart && (decHead + chunkSize) > outStart) {
// Case 3: only the last portion of the decoded data can be used
std::move(buffer.begin() + static_cast<long int>(outStart - decHead),
buffer.begin() + static_cast<long int>(chunkSize),
buffer.begin());
size_t validData = chunkSize - (outStart - decHead);
outBytes += validData;
lzmaStream.next_out = reinterpret_cast<uint8_t*>(buffer.data()) + validData;
lzmaStream.avail_out = buffer.size() - validData;
}
// Save the current status of the output buffer...
chunkSize = lzmaStream.avail_out;

// Advance the head of the decompressed data
decHead += chunkSize;
lzma_ret ret = lzma_code(&lzmaStream, LZMA_RUN);
if (ret != LZMA_OK && ret != LZMA_STREAM_END) {
throw std::string("Failed to decompress xz image file: ") + std::to_string(static_cast<unsigned int>(ret));
}

// ...and then extract the size of the decompressed chunk
chunkSize -= lzmaStream.avail_out;

if (decHead >= outStart && (decHead + chunkSize) <= outEnd) {
// Case 1: all decoded data can be used
outBytes += chunkSize;
} else if (decHead < outStart && (decHead + chunkSize) <= outStart) {
// Case 2: all decoded data shall be discarded
lzmaStream.next_out = reinterpret_cast<uint8_t*>(buffer.data());
lzmaStream.avail_out = static_cast<size_t>(buffer.size());
} else if (decHead < outStart && (decHead + chunkSize) > outStart) {
// Case 3: only the last portion of the decoded data can be used
std::move(buffer.begin() + static_cast<long int>(outStart - decHead),
buffer.begin() + static_cast<long int>(chunkSize),
buffer.begin());
size_t validData = chunkSize - (outStart - decHead);
outBytes += validData;
lzmaStream.next_out = reinterpret_cast<uint8_t*>(buffer.data()) + validData;
lzmaStream.avail_out = buffer.size() - validData;
}

// In case all the required data has been decompressed OR the XZ stream is ended
// OR the input file has been read completely, stop this decompression loop
if ((lzmaStream.avail_out == 0) || (ret == LZMA_STREAM_END) ||
(lzmaStream.avail_in == 0 && imgFile.eof())) {
break;
// Advance the head of the decompressed data
decHead += chunkSize;

// In case all the required data has been decompressed OR the XZ stream is ended
// OR the input file has been read completely, stop this decompression loop
if ((lzmaStream.avail_out == 0) || (ret == LZMA_STREAM_END) ||
(lzmaStream.avail_in == 0 && imgFile.eof())) {
break;
}
}
} else if (compressionType == "none") {
imgFile.seekg(static_cast<std::streamoff>(startBlock * bmap.blockSize), std::ios::beg);
imgFile.read(buffer.data(), static_cast<std::streamsize>(bufferSize));
outBytes = static_cast<size_t>(imgFile.gcount());
if (outBytes == 0 && imgFile.fail()) {
throw std::string("Failed to read from image file");
}
}
} else if (compressionType == "none") {
imgFile.seekg(static_cast<std::streamoff>(startBlock * bmap.blockSize), std::ios::beg);
imgFile.read(buffer.data(), static_cast<std::streamsize>(bufferSize));
outBytes = static_cast<size_t>(imgFile.gcount());
if (outBytes == 0 && imgFile.fail()) {
std::cerr << "Failed to read from image file" << std::endl;
close(dev_fd);
imgFile.close();
return 1;

// Compute and verify the checksum
std::string computedChecksum = computeSHA256(buffer, outBytes);
if (computedChecksum != range.checksum) {
std::stringstream err;
err << "Checksum verification failed for range: " << range.range << std::endl;
err << "Computed Checksum: " << computedChecksum << std::endl;
err << "Expected Checksum: " << range.checksum;
//std::cerr << "Buffer content (hex):" << std::endl;
//printBufferHex(buffer.data(), outBytes);
throw std::string(err.str());
}
}

// Compute and verify the checksum
char computedChecksum[CHECKSUM_LENGTH + 1];
computeSHA256(buffer.data(), outBytes, computedChecksum);
if (strcmp(computedChecksum, range.checksum.c_str()) != 0) {
std::cerr << "Checksum verification failed for range: " << range.range << std::endl;
std::cerr << "Computed Checksum: " << computedChecksum << std::endl;
std::cerr << "Expected Checksum: " << range.checksum << std::endl;
//std::cerr << "Buffer content (hex):" << std::endl;
//printBufferHex(buffer.data(), outBytes);
close(dev_fd);
if (compressionType == "gzip") {
gzclose(gzImg);
} else if (compressionType == "xz" || compressionType == "none") {
imgFile.close();
if (pwrite(dev_fd, buffer.data(), outBytes, static_cast<off_t>(startBlock * bmap.blockSize)) < 0) {
throw std::string("Write to device failed");
}
return 1;
}

if (pwrite(dev_fd, buffer.data(), outBytes, static_cast<off_t>(startBlock * bmap.blockSize)) < 0) {
std::cerr << "Write to device failed"<< std::endl;
close(dev_fd);
if (compressionType == "gzip") {
gzclose(gzImg);
} else if (compressionType == "xz" || compressionType == "none") {
imgFile.close();
}
return 1;
if (fsync(dev_fd) != 0) {
throw std::string("fsync failed after all writes");
}

std::cout << "Finished writing image to device." << std::endl;
}
catch (std::string& err) {
std::cerr << err << std::endl;
ret = -1;
}

if (fsync(dev_fd) != 0) {
std::cerr << "fsync failed after all writes"<< std::endl;
if (dev_fd >= 0) {
close(dev_fd);
}

if (imgFile.is_open()) {
imgFile.close();
}

close(dev_fd);
if (compressionType == "gzip") {
gzclose(gzImg);
} else if (compressionType == "xz" || compressionType == "none") {
imgFile.close();
} else if (compressionType == "xz") {
lzma_end(&lzmaStream);
}
std::cout << "Finished writing image to device." << std::endl;
return 0;

return ret;
}

int main(int argc, char *argv[]) {
Expand Down

0 comments on commit 1525385

Please sign in to comment.