Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MQTTRetainedMessageManager improvement proposal #1763

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Samples/Server/Server_Retained_Messages_Samples.cs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ public static async Task Persist_Retained_Messages()
// used to write all retained messages to dedicated files etc. Then all files must be loaded and a full list
// of retained messages must be provided in the loaded event.

var buffer = JsonSerializer.SerializeToUtf8Bytes(eventArgs.StoredRetainedMessages);
var retainedMessages = await server.GetRetainedMessagesAsync();
var buffer = JsonSerializer.SerializeToUtf8Bytes(retainedMessages);
await File.WriteAllBytesAsync(storePath, buffer);
Console.WriteLine("Retained messages saved.");
}
Expand Down
3 changes: 2 additions & 1 deletion Source/MQTTnet.TestApp/ServerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,8 @@ public static async Task RunAsync()
Directory.CreateDirectory(directory);
}

File.WriteAllText(Filename, JsonConvert.SerializeObject(e.StoredRetainedMessages));
var retainedMessages = mqttServer.GetRetainedMessagesAsync().GetAwaiter().GetResult();
File.WriteAllText(Filename, JsonConvert.SerializeObject(retainedMessages));
return CompletedTask.Instance;
};

Expand Down
2 changes: 1 addition & 1 deletion Source/MQTTnet.Tests/Server/General.cs
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ public async Task Persist_Retained_Message()
var s = await testEnvironment.StartServer();
s.RetainedMessageChangedAsync += e =>
{
savedRetainedMessages = e.StoredRetainedMessages;
savedRetainedMessages = s.GetRetainedMessagesAsync().GetAwaiter().GetResult().ToList();
return CompletedTask.Instance;
};

Expand Down
105 changes: 105 additions & 0 deletions Source/MQTTnet.Tests/Server/Retained_Messages_Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using MQTTnet.Client;
using MQTTnet.Formatter;
using MQTTnet.Internal;
using MQTTnet.Packets;
using MQTTnet.Protocol;

Expand Down Expand Up @@ -268,5 +269,109 @@ public async Task Server_Reports_Retained_Messages_Supported_V5()
Assert.IsTrue(connectResult.RetainAvailable);
}
}

[TestMethod]
public async Task Server_Invokes_Retained_Message_Events()
{
using (var testEnvironment = CreateTestEnvironment())
{
const int DefaultPort = 1883;

var mqttFactory = new MqttFactory();
var mqttServerOptions = mqttFactory.CreateServerOptionsBuilder().WithDefaultEndpoint().Build();
var server = mqttFactory.CreateMqttServer(mqttServerOptions);

// subscribe to retained message events

int loadMessagesCount = 0;
int addMessageCount = 0;
int removeMessageCount = 0;
int replaceMessageCount = 0;
int clearMessagesCount = 0;
server.LoadingRetainedMessageAsync += e =>
{
loadMessagesCount++;
return CompletedTask.Instance;
};
server.RetainedMessageChangedAsync += e =>
{
switch (e.ChangeType)
{
case MQTTnet.Server.RetainedMessageChangedEventArgs.RetainedMessageChangeType.Add:
addMessageCount++;
break;
case MQTTnet.Server.RetainedMessageChangedEventArgs.RetainedMessageChangeType.Replace:
replaceMessageCount++;
break;
case MQTTnet.Server.RetainedMessageChangedEventArgs.RetainedMessageChangeType.Remove:
removeMessageCount++;
break;
}
return CompletedTask.Instance;
};
server.RetainedMessagesClearedAsync += e =>
{
clearMessagesCount++;
return CompletedTask.Instance;
};

// start server
await server.StartAsync();

// load-event should be invoked when server starts
await LongTestDelay();
Assert.IsTrue(loadMessagesCount == 1);

// confirm that "load retained messages" event was triggered

var client = testEnvironment.CreateClient();
var connectResult = await client.ConnectAsync(
testEnvironment.Factory.CreateClientOptionsBuilder()
.WithProtocolVersion(MqttProtocolVersion.V500)
.WithTcpServer("127.0.0.1", DefaultPort)
.Build());

Assert.IsTrue(connectResult.RetainAvailable);

// send a message to be retained
var msg1 = new MqttApplicationMessageBuilder().WithTopic("topic1").WithPayload(new byte[] { 1, 2, 3 }).WithRetainFlag(true).Build();
await client.PublishAsync(msg1);

// send a message that is not retained
var msg2 = new MqttApplicationMessageBuilder().WithTopic("topic2").WithPayload(new byte[] { 4, 5, 6 }).Build();
await client.PublishAsync(msg2);

// send another message to be retained
var msg3 = new MqttApplicationMessageBuilder().WithTopic("topic3").WithPayload(new byte[] { 7, 8, 9 }).WithRetainFlag(true).Build();
await client.PublishAsync(msg3);

// two add-messages-events should be received
await LongTestDelay();
Assert.IsTrue(addMessageCount == 2);

// update payload to replace retained message
msg1.PayloadSegment = new System.ArraySegment<byte>(new byte[3] { 3, 2, 1});
await client.PublishAsync(msg1);

await LongTestDelay();
Assert.IsTrue(replaceMessageCount == 1);

// create empty payload to remove retained message
msg1.PayloadSegment = new System.ArraySegment<byte>(new byte[0]);
await client.PublishAsync(msg1);

await LongTestDelay();
Assert.IsTrue(removeMessageCount == 1);

// Deleting all message should result in a clear-messages-event
await server.DeleteRetainedMessagesAsync();

await LongTestDelay();
Assert.IsTrue(clearMessagesCount == 1);

await client.DisconnectAsync();
await server.StopAsync();
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,19 @@ namespace MQTTnet.Server
{
public sealed class RetainedMessageChangedEventArgs : EventArgs
{
public RetainedMessageChangedEventArgs(string clientId, MqttApplicationMessage changedRetainedMessage, List<MqttApplicationMessage> storedRetainedMessages)
public enum RetainedMessageChangeType { Add, Remove, Replace };

public RetainedMessageChangedEventArgs(string clientId,MqttApplicationMessage changedRetainedMessage, RetainedMessageChangeType changeType)
{
ClientId = clientId ?? throw new ArgumentNullException(nameof(clientId));
ChangedRetainedMessage = changedRetainedMessage ?? throw new ArgumentNullException(nameof(changedRetainedMessage));
StoredRetainedMessages = storedRetainedMessages ?? throw new ArgumentNullException(nameof(storedRetainedMessages));
ChangeType = changeType;
}

public MqttApplicationMessage ChangedRetainedMessage { get; }

public string ClientId { get; }

public List<MqttApplicationMessage> StoredRetainedMessages { get; }
public RetainedMessageChangeType ChangeType { get; }
}
}
26 changes: 12 additions & 14 deletions Source/MQTTnet/Server/Internal/MqttRetainedMessagesManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,7 @@ public async Task UpdateMessage(string clientId, MqttApplicationMessage applicat

try
{
List<MqttApplicationMessage> messagesForSave = null;
var saveIsRequired = false;
RetainedMessageChangedEventArgs.RetainedMessageChangeType? changeType = null;

lock (_messages)
{
Expand All @@ -74,39 +73,38 @@ public async Task UpdateMessage(string clientId, MqttApplicationMessage applicat

if (!hasPayload)
{
saveIsRequired = _messages.Remove(applicationMessage.Topic);
if (_messages.Remove(applicationMessage.Topic))
{
changeType = RetainedMessageChangedEventArgs.RetainedMessageChangeType.Remove;
}
_logger.Verbose("Client '{0}' cleared retained message for topic '{1}'.", clientId, applicationMessage.Topic);
}
else
{
if (!_messages.TryGetValue(applicationMessage.Topic, out var existingMessage))
{
_messages[applicationMessage.Topic] = applicationMessage;
saveIsRequired = true;
changeType = RetainedMessageChangedEventArgs.RetainedMessageChangeType.Add;
}
else
{
if (existingMessage.QualityOfServiceLevel != applicationMessage.QualityOfServiceLevel || !SequenceEqual(existingMessage.PayloadSegment, payloadSegment))
{
_messages[applicationMessage.Topic] = applicationMessage;
saveIsRequired = true;
changeType = RetainedMessageChangedEventArgs.RetainedMessageChangeType.Replace;
}
}

_logger.Verbose("Client '{0}' set retained message for topic '{1}'.", clientId, applicationMessage.Topic);
}

if (saveIsRequired)
{
messagesForSave = new List<MqttApplicationMessage>(_messages.Values);
}
}

if (saveIsRequired)
if (changeType != null)
{
using (await _storageAccessLock.EnterAsync().ConfigureAwait(false))
{
var eventArgs = new RetainedMessageChangedEventArgs(clientId, applicationMessage, messagesForSave);
var eventArgs = new RetainedMessageChangedEventArgs(clientId, applicationMessage, changeType.Value);

await _eventContainer.RetainedMessageChangedEvent.InvokeAsync(eventArgs).ConfigureAwait(false);
}
}
Expand Down Expand Up @@ -134,9 +132,9 @@ public Task<MqttApplicationMessage> GetMessage(string topic)
{
return Task.FromResult(message);
}

return null;
}

return Task.FromResult<MqttApplicationMessage>(null);
}

public async Task ClearMessages()
Expand Down