Skip to content

Commit

Permalink
[MLA-1783] built-in actuator type (#4950)
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris Elion authored Feb 24, 2021
1 parent ad620ec commit 60d5d93
Show file tree
Hide file tree
Showing 15 changed files with 225 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ namespace Unity.MLAgents.Extensions.Input
/// <see cref="Agent"/>'s <see cref="BehaviorParameters"/> indicate that the Agent is running in Heuristic Mode,
/// this Actuator will write actions from the <see cref="InputSystem"/> to the <see cref="ActionBuffers"/> object.
/// </summary>
public class InputActionActuator : IActuator, IHeuristicProvider
public class InputActionActuator : IActuator, IHeuristicProvider, IBuiltInActuator
{
readonly BehaviorParameters m_BehaviorParameters;
readonly InputAction m_Action;
Expand All @@ -35,8 +35,8 @@ public class InputActionActuator : IActuator, IHeuristicProvider
/// <param name="adaptor">The <see cref="IRLActionInputAdaptor"/> that will convert data between ML-Agents
/// and the <see cref="InputSystem"/>.</param>
public InputActionActuator(InputDevice inputDevice, BehaviorParameters behaviorParameters,
InputAction action,
IRLActionInputAdaptor adaptor)
InputAction action,
IRLActionInputAdaptor adaptor)
{
m_BehaviorParameters = behaviorParameters;
Name = $"InputActionActuator-{action.name}";
Expand Down Expand Up @@ -83,6 +83,12 @@ public void Heuristic(in ActionBuffers actionBuffersOut)
m_InputAdaptor.WriteToHeuristic(m_Action, actionBuffersOut);
Profiler.EndSample();
}

/// <inheritdoc/>
public BuiltInActuatorType GetBuiltInActuatorType()
{
return BuiltInActuatorType.InputActionActuator;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ namespace Unity.MLAgents.Extensions.Match3
/// Actuator for a Match3 game. It translates valid moves (defined by AbstractBoard.IsMoveValid())
/// in action masks, and applies the action to the board via AbstractBoard.MakeMove().
/// </summary>
public class Match3Actuator : IActuator, IHeuristicProvider
public class Match3Actuator : IActuator, IHeuristicProvider, IBuiltInActuator
{
protected AbstractBoard m_Board;
protected System.Random m_Random;
Expand Down Expand Up @@ -92,6 +92,12 @@ public void ResetData()
{
}

/// <inheritdoc/>
public BuiltInActuatorType GetBuiltInActuatorType()
{
return BuiltInActuatorType.Match3Actuator;
}

IEnumerable<int> InvalidMoveIndices()
{
var numValidMoves = m_Board.NumMoves();
Expand Down Expand Up @@ -179,6 +185,5 @@ protected virtual int EvalMovePoints(Move move)
{
return 1;
}

}
}
49 changes: 49 additions & 0 deletions com.unity.ml-agents/Runtime/Actuators/IBuiltInActuator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
namespace Unity.MLAgents.Actuators
{
/// <summary>
/// Identifiers for "built in" actuator types.
/// These are only used for analytics, and should not be used for any runtime decisions.
///
/// NOTE: Do not renumber these, since the values are used for analytics. Renaming is allowed though.
/// </summary>
public enum BuiltInActuatorType
{
/// <summary>
/// Default Sensor type if it cannot be determined.
/// </summary>
Unknown = 0,

/// <summary>
/// VectorActuator used by the Agent
/// </summary>
AgentVectorActuator = 1,

/// <summary>
/// Corresponds to <see cref="VectorActuator"/>
/// </summary>
VectorActuator = 2,

/// <summary>
/// Corresponds to the Match3Actuator in com.unity.ml-agents.extensions.
/// </summary>
Match3Actuator = 3,

/// <summary>
/// Corresponds to the InputActionActuator in com.unity.ml-agents.extensions.
/// </summary>
InputActionActuator = 4,
}

/// <summary>
/// Interface for actuators that are provided as part of ML-Agents.
/// User-implemented actuators don't need to use this interface.
/// </summary>
internal interface IBuiltInActuator
{
/// <summary>
/// Return the corresponding BuiltInActuatorType for the actuator.
/// </summary>
/// <returns>A BuiltInActuatorType corresponding to the actuator.</returns>
BuiltInActuatorType GetBuiltInActuatorType();
}
}

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion com.unity.ml-agents/Runtime/Actuators/VectorActuator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ namespace Unity.MLAgents.Actuators
/// <summary>
/// IActuator implementation that forwards calls to an <see cref="IActionReceiver"/> and an <see cref="IHeuristicProvider"/>.
/// </summary>
internal class VectorActuator : IActuator, IHeuristicProvider
internal class VectorActuator : IActuator, IHeuristicProvider, IBuiltInActuator
{
IActionReceiver m_ActionReceiver;
IHeuristicProvider m_HeuristicProvider;
Expand Down Expand Up @@ -95,5 +95,11 @@ public void WriteDiscreteActionMask(IDiscreteActionMask actionMask)

/// <inheritdoc />
public string Name { get; }

/// <inheritdoc />
public virtual BuiltInActuatorType GetBuiltInActuatorType()
{
return BuiltInActuatorType.VectorActuator;
}
}
}
21 changes: 20 additions & 1 deletion com.unity.ml-agents/Runtime/Agent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,25 @@ public void CopyActions(ActionBuffers actionBuffers)
}
}

/// <summary>
/// Simple wrapper around VectorActuator that overrides GetBuiltInActuatorType
/// so that it can be distinguished from a standard VectorActuator.
/// </summary>
internal class AgentVectorActuator : VectorActuator
{
public AgentVectorActuator(IActionReceiver actionReceiver,
IHeuristicProvider heuristicProvider,
ActionSpec actionSpec,
string name = "VectorActuator"
) : base(actionReceiver, heuristicProvider, actionSpec, name)
{ }

public override BuiltInActuatorType GetBuiltInActuatorType()
{
return BuiltInActuatorType.AgentVectorActuator;
}
}

/// <summary>
/// An agent is an actor that can observe its environment, decide on the
/// best course of action using those observations, and execute those actions
Expand Down Expand Up @@ -997,7 +1016,7 @@ void InitializeActuators()
// Support legacy OnActionReceived
// TODO don't set this up if the sizes are 0?
var param = m_PolicyFactory.BrainParameters;
m_VectorActuator = new VectorActuator(this, this, param.ActionSpec);
m_VectorActuator = new AgentVectorActuator(this, this, param.ActionSpec);
m_ActuatorManager = new ActuatorManager(attachedActuators.Length + 1);
m_LegacyActionCache = new float[m_VectorActuator.TotalNumberOfActions()];
m_LegacyHeuristicCache = new float[m_VectorActuator.TotalNumberOfActions()];
Expand Down
31 changes: 31 additions & 0 deletions com.unity.ml-agents/Runtime/Analytics/Events.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ internal struct InferenceEvent
public int InferenceDevice;
public List<EventObservationSpec> ObservationSpecs;
public EventActionSpec ActionSpec;
public List<EventActuatorInfo> ActuatorInfos;
public int MemorySize;
public long TotalWeightSizeBytes;
public string ModelHash;
Expand Down Expand Up @@ -48,6 +49,35 @@ public static EventActionSpec FromActionSpec(ActionSpec actionSpec)
}
}

/// <summary>
/// Information about an actuator.
/// </summary>
[Serializable]
internal struct EventActuatorInfo
{
public int BuiltInActuatorType;
public int NumContinuousActions;
public int NumDiscreteActions;

public static EventActuatorInfo FromActuator(IActuator actuator)
{
BuiltInActuatorType builtInActuatorType = Actuators.BuiltInActuatorType.Unknown;
if (actuator is IBuiltInActuator builtInActuator)
{
builtInActuatorType = builtInActuator.GetBuiltInActuatorType();
}

var actionSpec = actuator.ActionSpec;

return new EventActuatorInfo
{
BuiltInActuatorType = (int)builtInActuatorType,
NumContinuousActions = actionSpec.NumContinuousActions,
NumDiscreteActions = actionSpec.NumDiscreteActions
};
}
}

/// <summary>
/// Information about one dimension of an observation.
/// </summary>
Expand Down Expand Up @@ -101,6 +131,7 @@ internal struct RemotePolicyInitializedEvent
public string BehaviorName;
public List<EventObservationSpec> ObservationSpecs;
public EventActionSpec ActionSpec;
public List<EventActuatorInfo> ActuatorInfos;

/// <summary>
/// This will be the same as TrainingEnvironmentInitializedEvent if available, but
Expand Down
18 changes: 14 additions & 4 deletions com.unity.ml-agents/Runtime/Analytics/InferenceAnalytics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,15 @@ public static bool IsAnalyticsEnabled()
/// <param name="inferenceDevice">Whether inference is being performed on the CPU or GPU</param>
/// <param name="sensors">List of ISensors for the Agent. Used to generate information about the observation space.</param>
/// <param name="actionSpec">ActionSpec for the Agent. Used to generate information about the action space.</param>
/// <param name="actuators">List of IActuators for the Agent. Used to generate information about the action space.</param>
/// <returns></returns>
public static void InferenceModelSet(
NNModel nnModel,
string behaviorName,
InferenceDevice inferenceDevice,
IList<ISensor> sensors,
ActionSpec actionSpec
ActionSpec actionSpec,
IList<IActuator> actuators
)
{
// The event shouldn't be able to report if this is disabled but if we know we're not going to report
Expand All @@ -112,9 +114,9 @@ ActionSpec actionSpec
return;
}

var data = GetEventForModel(nnModel, behaviorName, inferenceDevice, sensors, actionSpec);
var data = GetEventForModel(nnModel, behaviorName, inferenceDevice, sensors, actionSpec, actuators);
// Note - to debug, use JsonUtility.ToJson on the event.
//Debug.Log(JsonUtility.ToJson(data, true));
// Debug.Log(JsonUtility.ToJson(data, true));
#if UNITY_EDITOR
if (AnalyticsUtils.s_SendEditorAnalytics)
{
Expand All @@ -133,13 +135,15 @@ ActionSpec actionSpec
/// <param name="inferenceDevice"></param>
/// <param name="sensors"></param>
/// <param name="actionSpec"></param>
/// <param name="actuators"></param>
/// <returns></returns>
internal static InferenceEvent GetEventForModel(
NNModel nnModel,
string behaviorName,
InferenceDevice inferenceDevice,
IList<ISensor> sensors,
ActionSpec actionSpec
ActionSpec actionSpec,
IList<IActuator> actuators
)
{
var barracudaModel = ModelLoader.Load(nnModel);
Expand Down Expand Up @@ -175,6 +179,12 @@ ActionSpec actionSpec
inferenceEvent.ObservationSpecs.Add(EventObservationSpec.FromSensor(sensor));
}

inferenceEvent.ActuatorInfos = new List<EventActuatorInfo>(actuators.Count);
foreach (var actuator in actuators)
{
inferenceEvent.ActuatorInfos.Add(EventActuatorInfo.FromActuator(actuator));
}

inferenceEvent.TotalWeightSizeBytes = GetModelWeightSize(barracudaModel);
inferenceEvent.ModelHash = GetModelHash(barracudaModel);
return inferenceEvent;
Expand Down
17 changes: 13 additions & 4 deletions com.unity.ml-agents/Runtime/Analytics/TrainingAnalytics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ public static void TrainingEnvironmentInitialized(TrainingEnvironmentInitialized
public static void RemotePolicyInitialized(
string fullyQualifiedBehaviorName,
IList<ISensor> sensors,
ActionSpec actionSpec
ActionSpec actionSpec,
IList<IActuator> actuators
)
{
if (!IsAnalyticsEnabled())
Expand All @@ -158,7 +159,7 @@ ActionSpec actionSpec
return;
}

var data = GetEventForRemotePolicy(behaviorName, sensors, actionSpec);
var data = GetEventForRemotePolicy(behaviorName, sensors, actionSpec, actuators);
// Note - to debug, use JsonUtility.ToJson on the event.
// Debug.Log(
// $"Would send event {k_RemotePolicyInitializedEventName} with body {JsonUtility.ToJson(data, true)}"
Expand Down Expand Up @@ -220,10 +221,12 @@ public static void TrainingBehaviorInitialized(TrainingBehaviorInitializedEvent
#endif
}

static RemotePolicyInitializedEvent GetEventForRemotePolicy(
internal static RemotePolicyInitializedEvent GetEventForRemotePolicy(
string behaviorName,
IList<ISensor> sensors,
ActionSpec actionSpec)
ActionSpec actionSpec,
IList<IActuator> actuators
)
{
var remotePolicyEvent = new RemotePolicyInitializedEvent();

Expand All @@ -238,6 +241,12 @@ static RemotePolicyInitializedEvent GetEventForRemotePolicy(
remotePolicyEvent.ObservationSpecs.Add(EventObservationSpec.FromSensor(sensor));
}

remotePolicyEvent.ActuatorInfos = new List<EventActuatorInfo>(actuators.Count);
foreach (var actuator in actuators)
{
remotePolicyEvent.ActuatorInfos.Add(EventActuatorInfo.FromActuator(actuator));
}

remotePolicyEvent.MLAgentsEnvsVersion = s_TrainerPackageVersion;
remotePolicyEvent.TrainerCommunicationVersion = s_TrainerCommunicationVersion;
return remotePolicyEvent;
Expand Down
10 changes: 9 additions & 1 deletion com.unity.ml-agents/Runtime/Policies/BarracudaPolicy.cs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,11 @@ internal class BarracudaPolicy : IPolicy

private string m_BehaviorName;

/// <summary>
/// List of actuators, only used for analytics
/// </summary>
private IList<IActuator> m_Actuators;

/// <summary>
/// Whether or not we've tried to send analytics for this model. We only ever try to send once per policy,
/// and do additional deduplication in the analytics code.
Expand All @@ -57,6 +62,7 @@ internal class BarracudaPolicy : IPolicy
/// <inheritdoc />
public BarracudaPolicy(
ActionSpec actionSpec,
IList<IActuator> actuators,
NNModel model,
InferenceDevice inferenceDevice,
string behaviorName
Expand All @@ -66,6 +72,7 @@ string behaviorName
m_ModelRunner = modelRunner;
m_BehaviorName = behaviorName;
m_ActionSpec = actionSpec;
m_Actuators = actuators;
}

/// <inheritdoc />
Expand All @@ -79,7 +86,8 @@ public void RequestDecision(AgentInfo info, List<ISensor> sensors)
m_BehaviorName,
m_ModelRunner.InferenceDevice,
sensors,
m_ActionSpec
m_ActionSpec,
m_Actuators
);
}
m_AgentId = info.episodeId;
Expand Down
6 changes: 3 additions & 3 deletions com.unity.ml-agents/Runtime/Policies/BehaviorParameters.cs
Original file line number Diff line number Diff line change
Expand Up @@ -221,16 +221,16 @@ internal IPolicy GeneratePolicy(ActionSpec actionSpec, ActuatorManager actuatorM
"Either assign a model, or change to a different Behavior Type."
);
}
return new BarracudaPolicy(actionSpec, m_Model, m_InferenceDevice, m_BehaviorName);
return new BarracudaPolicy(actionSpec, actuatorManager, m_Model, m_InferenceDevice, m_BehaviorName);
}
case BehaviorType.Default:
if (Academy.Instance.IsCommunicatorOn)
{
return new RemotePolicy(actionSpec, FullyQualifiedBehaviorName);
return new RemotePolicy(actionSpec, actuatorManager, FullyQualifiedBehaviorName);
}
if (m_Model != null)
{
return new BarracudaPolicy(actionSpec, m_Model, m_InferenceDevice, m_BehaviorName);
return new BarracudaPolicy(actionSpec, actuatorManager, m_Model, m_InferenceDevice, m_BehaviorName);
}
else
{
Expand Down
Loading

0 comments on commit 60d5d93

Please sign in to comment.