From 1ea994d9acb34b997d32085bbe5bbcfc0e5bd5bc Mon Sep 17 00:00:00 2001 From: logicaloud Date: Mon, 12 Jun 2023 23:57:32 +1200 Subject: [PATCH] MQTTRetainedMessageManager improvements --- .../Server_Retained_Messages_Samples.cs | 3 +- Source/MQTTnet.TestApp/ServerTest.cs | 3 +- Source/MQTTnet.Tests/Server/General.cs | 2 +- .../Server/Retained_Messages_Tests.cs | 105 ++++++++++++++++++ .../Events/RetainedMessageChangedEventArgs.cs | 8 +- .../Internal/MqttRetainedMessagesManager.cs | 26 ++--- 6 files changed, 127 insertions(+), 20 deletions(-) diff --git a/Samples/Server/Server_Retained_Messages_Samples.cs b/Samples/Server/Server_Retained_Messages_Samples.cs index e392460f7..e6b534565 100644 --- a/Samples/Server/Server_Retained_Messages_Samples.cs +++ b/Samples/Server/Server_Retained_Messages_Samples.cs @@ -56,7 +56,8 @@ public static async Task Persist_Retained_Messages() // used to write all retained messages to dedicated files etc. Then all files must be loaded and a full list // of retained messages must be provided in the loaded event. - var buffer = JsonSerializer.SerializeToUtf8Bytes(eventArgs.StoredRetainedMessages); + var retainedMessages = await server.GetRetainedMessagesAsync(); + var buffer = JsonSerializer.SerializeToUtf8Bytes(retainedMessages); await File.WriteAllBytesAsync(storePath, buffer); Console.WriteLine("Retained messages saved."); } diff --git a/Source/MQTTnet.TestApp/ServerTest.cs b/Source/MQTTnet.TestApp/ServerTest.cs index bd5961bc5..9486c6052 100644 --- a/Source/MQTTnet.TestApp/ServerTest.cs +++ b/Source/MQTTnet.TestApp/ServerTest.cs @@ -67,7 +67,8 @@ public static async Task RunAsync() Directory.CreateDirectory(directory); } - File.WriteAllText(Filename, JsonConvert.SerializeObject(e.StoredRetainedMessages)); + var retainedMessages = mqttServer.GetRetainedMessagesAsync().GetAwaiter().GetResult(); + File.WriteAllText(Filename, JsonConvert.SerializeObject(retainedMessages)); return CompletedTask.Instance; }; diff --git a/Source/MQTTnet.Tests/Server/General.cs b/Source/MQTTnet.Tests/Server/General.cs index 2e0061e16..fb999dd64 100644 --- a/Source/MQTTnet.Tests/Server/General.cs +++ b/Source/MQTTnet.Tests/Server/General.cs @@ -398,7 +398,7 @@ public async Task Persist_Retained_Message() var s = await testEnvironment.StartServer(); s.RetainedMessageChangedAsync += e => { - savedRetainedMessages = e.StoredRetainedMessages; + savedRetainedMessages = s.GetRetainedMessagesAsync().GetAwaiter().GetResult().ToList(); return CompletedTask.Instance; }; diff --git a/Source/MQTTnet.Tests/Server/Retained_Messages_Tests.cs b/Source/MQTTnet.Tests/Server/Retained_Messages_Tests.cs index b27741575..d7c465da0 100644 --- a/Source/MQTTnet.Tests/Server/Retained_Messages_Tests.cs +++ b/Source/MQTTnet.Tests/Server/Retained_Messages_Tests.cs @@ -7,6 +7,7 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using MQTTnet.Client; using MQTTnet.Formatter; +using MQTTnet.Internal; using MQTTnet.Packets; using MQTTnet.Protocol; @@ -268,5 +269,109 @@ public async Task Server_Reports_Retained_Messages_Supported_V5() Assert.IsTrue(connectResult.RetainAvailable); } } + + [TestMethod] + public async Task Server_Invokes_Retained_Message_Events() + { + using (var testEnvironment = CreateTestEnvironment()) + { + const int DefaultPort = 1883; + + var mqttFactory = new MqttFactory(); + var mqttServerOptions = mqttFactory.CreateServerOptionsBuilder().WithDefaultEndpoint().Build(); + var server = mqttFactory.CreateMqttServer(mqttServerOptions); + + // subscribe to retained message events + + int loadMessagesCount = 0; + int addMessageCount = 0; + int removeMessageCount = 0; + int replaceMessageCount = 0; + int clearMessagesCount = 0; + server.LoadingRetainedMessageAsync += e => + { + loadMessagesCount++; + return CompletedTask.Instance; + }; + server.RetainedMessageChangedAsync += e => + { + switch (e.ChangeType) + { + case MQTTnet.Server.RetainedMessageChangedEventArgs.RetainedMessageChangeType.Add: + addMessageCount++; + break; + case MQTTnet.Server.RetainedMessageChangedEventArgs.RetainedMessageChangeType.Replace: + replaceMessageCount++; + break; + case MQTTnet.Server.RetainedMessageChangedEventArgs.RetainedMessageChangeType.Remove: + removeMessageCount++; + break; + } + return CompletedTask.Instance; + }; + server.RetainedMessagesClearedAsync += e => + { + clearMessagesCount++; + return CompletedTask.Instance; + }; + + // start server + await server.StartAsync(); + + // load-event should be invoked when server starts + await LongTestDelay(); + Assert.IsTrue(loadMessagesCount == 1); + + // confirm that "load retained messages" event was triggered + + var client = testEnvironment.CreateClient(); + var connectResult = await client.ConnectAsync( + testEnvironment.Factory.CreateClientOptionsBuilder() + .WithProtocolVersion(MqttProtocolVersion.V500) + .WithTcpServer("127.0.0.1", DefaultPort) + .Build()); + + Assert.IsTrue(connectResult.RetainAvailable); + + // send a message to be retained + var msg1 = new MqttApplicationMessageBuilder().WithTopic("topic1").WithPayload(new byte[] { 1, 2, 3 }).WithRetainFlag(true).Build(); + await client.PublishAsync(msg1); + + // send a message that is not retained + var msg2 = new MqttApplicationMessageBuilder().WithTopic("topic2").WithPayload(new byte[] { 4, 5, 6 }).Build(); + await client.PublishAsync(msg2); + + // send another message to be retained + var msg3 = new MqttApplicationMessageBuilder().WithTopic("topic3").WithPayload(new byte[] { 7, 8, 9 }).WithRetainFlag(true).Build(); + await client.PublishAsync(msg3); + + // two add-messages-events should be received + await LongTestDelay(); + Assert.IsTrue(addMessageCount == 2); + + // update payload to replace retained message + msg1.PayloadSegment = new System.ArraySegment(new byte[3] { 3, 2, 1}); + await client.PublishAsync(msg1); + + await LongTestDelay(); + Assert.IsTrue(replaceMessageCount == 1); + + // create empty payload to remove retained message + msg1.PayloadSegment = new System.ArraySegment(new byte[0]); + await client.PublishAsync(msg1); + + await LongTestDelay(); + Assert.IsTrue(removeMessageCount == 1); + + // Deleting all message should result in a clear-messages-event + await server.DeleteRetainedMessagesAsync(); + + await LongTestDelay(); + Assert.IsTrue(clearMessagesCount == 1); + + await client.DisconnectAsync(); + await server.StopAsync(); + } + } } } \ No newline at end of file diff --git a/Source/MQTTnet/Server/Events/RetainedMessageChangedEventArgs.cs b/Source/MQTTnet/Server/Events/RetainedMessageChangedEventArgs.cs index c7beb8076..0e93a3371 100644 --- a/Source/MQTTnet/Server/Events/RetainedMessageChangedEventArgs.cs +++ b/Source/MQTTnet/Server/Events/RetainedMessageChangedEventArgs.cs @@ -9,17 +9,19 @@ namespace MQTTnet.Server { public sealed class RetainedMessageChangedEventArgs : EventArgs { - public RetainedMessageChangedEventArgs(string clientId, MqttApplicationMessage changedRetainedMessage, List storedRetainedMessages) + public enum RetainedMessageChangeType { Add, Remove, Replace }; + + public RetainedMessageChangedEventArgs(string clientId,MqttApplicationMessage changedRetainedMessage, RetainedMessageChangeType changeType) { ClientId = clientId ?? throw new ArgumentNullException(nameof(clientId)); ChangedRetainedMessage = changedRetainedMessage ?? throw new ArgumentNullException(nameof(changedRetainedMessage)); - StoredRetainedMessages = storedRetainedMessages ?? throw new ArgumentNullException(nameof(storedRetainedMessages)); + ChangeType = changeType; } public MqttApplicationMessage ChangedRetainedMessage { get; } public string ClientId { get; } - public List StoredRetainedMessages { get; } + public RetainedMessageChangeType ChangeType { get; } } } \ No newline at end of file diff --git a/Source/MQTTnet/Server/Internal/MqttRetainedMessagesManager.cs b/Source/MQTTnet/Server/Internal/MqttRetainedMessagesManager.cs index 4dc0b2c06..30211efeb 100644 --- a/Source/MQTTnet/Server/Internal/MqttRetainedMessagesManager.cs +++ b/Source/MQTTnet/Server/Internal/MqttRetainedMessagesManager.cs @@ -64,8 +64,7 @@ public async Task UpdateMessage(string clientId, MqttApplicationMessage applicat try { - List messagesForSave = null; - var saveIsRequired = false; + RetainedMessageChangedEventArgs.RetainedMessageChangeType? changeType = null; lock (_messages) { @@ -74,7 +73,10 @@ public async Task UpdateMessage(string clientId, MqttApplicationMessage applicat if (!hasPayload) { - saveIsRequired = _messages.Remove(applicationMessage.Topic); + if (_messages.Remove(applicationMessage.Topic)) + { + changeType = RetainedMessageChangedEventArgs.RetainedMessageChangeType.Remove; + } _logger.Verbose("Client '{0}' cleared retained message for topic '{1}'.", clientId, applicationMessage.Topic); } else @@ -82,31 +84,27 @@ public async Task UpdateMessage(string clientId, MqttApplicationMessage applicat if (!_messages.TryGetValue(applicationMessage.Topic, out var existingMessage)) { _messages[applicationMessage.Topic] = applicationMessage; - saveIsRequired = true; + changeType = RetainedMessageChangedEventArgs.RetainedMessageChangeType.Add; } else { if (existingMessage.QualityOfServiceLevel != applicationMessage.QualityOfServiceLevel || !SequenceEqual(existingMessage.PayloadSegment, payloadSegment)) { _messages[applicationMessage.Topic] = applicationMessage; - saveIsRequired = true; + changeType = RetainedMessageChangedEventArgs.RetainedMessageChangeType.Replace; } } _logger.Verbose("Client '{0}' set retained message for topic '{1}'.", clientId, applicationMessage.Topic); } - - if (saveIsRequired) - { - messagesForSave = new List(_messages.Values); - } } - if (saveIsRequired) + if (changeType != null) { using (await _storageAccessLock.EnterAsync().ConfigureAwait(false)) { - var eventArgs = new RetainedMessageChangedEventArgs(clientId, applicationMessage, messagesForSave); + var eventArgs = new RetainedMessageChangedEventArgs(clientId, applicationMessage, changeType.Value); + await _eventContainer.RetainedMessageChangedEvent.InvokeAsync(eventArgs).ConfigureAwait(false); } } @@ -134,9 +132,9 @@ public Task GetMessage(string topic) { return Task.FromResult(message); } - - return null; } + + return Task.FromResult(null); } public async Task ClearMessages()