Skip to content

Commit

Permalink
Merge pull request #17 from aau-network-security/develop
Browse files Browse the repository at this point in the history
Release v1.0.2
  • Loading branch information
Mikkelhost authored Apr 24, 2024
2 parents a0a9840 + 273ac27 commit 60d52b6
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 18 deletions.
30 changes: 17 additions & 13 deletions internal/daemon/adminAgents.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,9 +175,9 @@ func (d *daemon) getAgents(c *gin.Context) {

admin, err := d.getUserFromGinContext(c)
if err != nil {
log.Error().Err(err).Msg("error getting user from gin context")
c.JSON(http.StatusInternalServerError, APIResponse{Status: "Internal Server Error"})
return
log.Error().Err(err).Msg("error getting user from gin context")
c.JSON(http.StatusInternalServerError, APIResponse{Status: "Internal Server Error"})
return
}
d.auditLogger.Info().
Time("UTC", time.Now().UTC()).
Expand Down Expand Up @@ -249,9 +249,9 @@ func (d *daemon) deleteAgent(c *gin.Context) {

admin, err := d.getUserFromGinContext(c)
if err != nil {
log.Error().Err(err).Msg("error getting user from gin context")
c.JSON(http.StatusInternalServerError, APIResponse{Status: "Internal Server Error"})
return
log.Error().Err(err).Msg("error getting user from gin context")
c.JSON(http.StatusInternalServerError, APIResponse{Status: "Internal Server Error"})
return
}
d.auditLogger.Info().
Time("UTC", time.Now().UTC()).
Expand Down Expand Up @@ -383,14 +383,18 @@ func (d *daemon) reconnectAgent(c *gin.Context) {
}

d.agentPool.addAgent(agentForPool)

d.eventpool.M.RLock()
for _, event := range d.eventpool.Events {
event.M.Lock()
for _, lab := range event.Labs {

if lab.ParentAgent.Name == agentForPool.Name {
lab.Conn = conn
}
}
event.M.Unlock()
}
d.eventpool.M.RUnlock()

c.JSON(http.StatusOK, APIResponse{Status: "OK"})
return
Expand All @@ -402,9 +406,9 @@ func (d *daemon) reconnectAgent(c *gin.Context) {
func (d *daemon) lockAgentState(c *gin.Context) {
admin, err := d.getUserFromGinContext(c)
if err != nil {
log.Error().Err(err).Msg("error getting user from gin context")
c.JSON(http.StatusInternalServerError, APIResponse{Status: "Internal Server Error"})
return
log.Error().Err(err).Msg("error getting user from gin context")
c.JSON(http.StatusInternalServerError, APIResponse{Status: "Internal Server Error"})
return
}
d.auditLogger.Info().
Time("UTC", time.Now().UTC()).
Expand Down Expand Up @@ -439,9 +443,9 @@ func (d *daemon) lockAgentState(c *gin.Context) {
func (d *daemon) unlockAgentState(c *gin.Context) {
admin, err := d.getUserFromGinContext(c)
if err != nil {
log.Error().Err(err).Msg("error getting user from gin context")
c.JSON(http.StatusInternalServerError, APIResponse{Status: "Internal Server Error"})
return
log.Error().Err(err).Msg("error getting user from gin context")
c.JSON(http.StatusInternalServerError, APIResponse{Status: "Internal Server Error"})
return
}
d.auditLogger.Info().
Time("UTC", time.Now().UTC()).
Expand Down
4 changes: 4 additions & 0 deletions internal/daemon/adminEvents.go
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,11 @@ func (d *daemon) getEvents(c *gin.Context) {
if dbEvent.Status == 0 {
event, err := d.eventpool.GetEvent(dbEvent.Tag)
if err == nil {
event.M.RLock()
for _, lab := range event.Labs {
labs = append(labs, lab)
}
event.M.RUnlock()
}
}

Expand Down Expand Up @@ -384,9 +386,11 @@ func (d *daemon) getEvents(c *gin.Context) {
if dbEvent.Status == 0 {
event, err := d.eventpool.GetEvent(dbEvent.Tag)
if err == nil {
event.M.RLock()
for _, lab := range event.Labs {
labs = append(labs, lab)
}
event.M.RUnlock()
}
}

Expand Down
71 changes: 70 additions & 1 deletion internal/daemon/agentpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ func (ap *AgentPool) connectToMonitoringStream(routineCtx context.Context, a *Ag
LabInfo: l,
}
event.UnassignedVpnLabs <- agentLab
event.M.Lock()
event.Labs[l.Tag] = agentLab
event.M.Unlock()
saveState(eventPool, statePath)
continue
}
Expand All @@ -128,7 +130,9 @@ func (ap *AgentPool) connectToMonitoringStream(routineCtx context.Context, a *Ag
LabInfo: l,
}
event.UnassignedBrowserLabs <- agentLab
event.M.Lock()
event.Labs[l.Tag] = agentLab
event.M.Unlock()
saveState(eventPool, statePath)
continue
}
Expand Down Expand Up @@ -658,7 +662,7 @@ func (d *daemon) agentSyncRoutine(ticker *time.Ticker) {
}
}(client)
}

}
}
}
Expand All @@ -671,3 +675,68 @@ func (agent *Agent) GetName() string {

return agent.Name
}

func (d *daemon) agentReconnectionRoutine(ticker *time.Ticker) {
log.Info().Msg("[agent-reconnection-routine] starting routine")
for range ticker.C {
ctx := context.Background()
dbAgents, err := d.db.GetAgents(ctx)
if err != nil {
log.Error().Err(err).Msg("[agent-reconnection-routine] error getting agents from database")
continue
}

for _, dbAgent := range dbAgents {
if _, err := d.agentPool.getAgent(dbAgent.Name); err != nil {
log.Debug().Str("agent", dbAgent.Name).Msg("[agent-reconnection-routine] agent not found in agentpool, reconnecting...")
serviceConf := ServiceConfig{
Grpc: dbAgent.Url,
AuthKey: dbAgent.AuthKey,
SignKey: dbAgent.SignKey,
TLSEnabled: dbAgent.Tls,
}
conn, memoryInstalled, err := NewAgentConnection(serviceConf)
if err != nil {
log.Error().Err(err).Msg("error reconnecting to agent")
continue
}

streamCtx, cancel := context.WithCancel(context.Background())
agentForPool := &Agent{
M: sync.RWMutex{},
Name: dbAgent.Name,
Url: dbAgent.Url,
Tls: dbAgent.Tls,
Conn: conn,
Weight: dbAgent.Weight,
RequestsLeft: dbAgent.Weight,
StateLock: false,
Errors: []error{},
Close: cancel,
Resources: AgentResources{
MemoryInstalled: memoryInstalled,
},
}

if err := d.agentPool.connectToStreams(streamCtx, agentForPool, d.eventpool, d.conf.StatePath); err != nil {
log.Error().Err(err).Msg("error connecting to agent streams")
continue
}

d.agentPool.addAgent(agentForPool)
d.eventpool.M.RLock()
for _, event := range d.eventpool.Events {
event.M.Lock()
for _, lab := range event.Labs {

if lab.ParentAgent.Name == agentForPool.Name {
lab.Conn = conn
}
}
event.M.Unlock()
}
d.eventpool.M.RUnlock()
}
}
}
}
11 changes: 11 additions & 0 deletions internal/daemon/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,9 @@ func (d *daemon) Run() error {
agentSyncRoutineTicker := time.NewTicker(30 * time.Second)
go d.agentSyncRoutine(agentSyncRoutineTicker)

agentReconnectionTicker := time.NewTicker(10 * time.Second)
go d.agentReconnectionRoutine(agentReconnectionTicker)

listeningAddress := fmt.Sprintf("%s:%d", d.conf.ListeningIp, d.conf.Port)
return r.Run(listeningAddress)
}
Expand All @@ -363,7 +366,9 @@ func (d *daemon) labExpiryRoutine() {
for _, event := range d.eventpool.Events {
var wg sync.WaitGroup
anyLabsClosed := false
event.M.RLock()
for _, team := range event.Teams {
team.M.RLock()
if team.Lab != nil {
if time.Now().After(team.Lab.ExpiresAtTime) {
if team.Lab.Conn != nil {
Expand All @@ -372,8 +377,12 @@ func (d *daemon) labExpiryRoutine() {
go func(team *Team, event *Event) {
defer wg.Done()
defer func() {
event.M.Lock()
delete(event.Labs, team.Lab.LabInfo.Tag)
event.M.Unlock()
team.M.Lock()
team.Lab = nil
team.M.Unlock()
saveState(d.eventpool, d.conf.StatePath)
sendCommandToTeam(team, updateTeam)
}()
Expand All @@ -389,7 +398,9 @@ func (d *daemon) labExpiryRoutine() {
}
}
}
team.M.RUnlock()
}
event.M.RUnlock()
wg.Wait()
if anyLabsClosed {
broadCastCommandToEventTeams(event, updateEventInfo)
Expand Down
11 changes: 7 additions & 4 deletions internal/daemon/eventpool.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func (ep *EventPool) GetEvent(eventTag string) (*Event, error) {
return event, nil
}

func (ep *EventPool) GetAllEvents() (map[string]*Event) {
func (ep *EventPool) GetAllEvents() map[string]*Event {
ep.M.RLock()
defer ep.M.RUnlock()

Expand All @@ -63,11 +63,13 @@ func (ep *EventPool) GetAllAgentLabsForAgent(agentName string) []*AgentLab {

var labsForAgent []*AgentLab
for _, event := range ep.Events {
event.M.RLock()
for _, lab := range event.Labs {
if lab.ParentAgent.Name == agentName {
labsForAgent = append(labsForAgent, lab)
}
}
event.M.RUnlock()
}

return labsForAgent
Expand Down Expand Up @@ -108,10 +110,12 @@ func (event *Event) AddTeam(team *Team) {

// Calculates the current amount of labs for an event then checks if it has passed or equal to the configured amount of maximum labs for event
func (event *Event) IsMaxLabsReached() bool {
event.M.RLock()
defer event.M.RUnlock()
// First get amount of teams waiting for labs
currentNumberOfLabs := event.TeamsWaitingForBrowserLabs.Len() + event.TeamsWaitingForVpnLabs.Len()
for _, team := range event.Teams {
if team.Status == WaitingForLab || team.Status == InQueue{
if team.Status == WaitingForLab || team.Status == InQueue {
currentNumberOfLabs += 1
}
}
Expand All @@ -138,7 +142,7 @@ func (event *Event) IsMaxLabsReached() bool {
This had the unfortunate effect of spending 1 core on the CPU per event created...
Short minded fix is currently inserting a 1 milisecond delay...
*/
func (event *Event) startQueueHandlers(eventPool *EventPool, statePath string, labExpiry time.Duration ) {
func (event *Event) startQueueHandlers(eventPool *EventPool, statePath string, labExpiry time.Duration) {
browserQueueHandler := func() {
log.Debug().Msg("Waiting for teams to enter browser lab queue")
for {
Expand Down Expand Up @@ -251,4 +255,3 @@ func (team *Team) LockForFunc(function func()) {
}

// Lab

0 comments on commit 60d52b6

Please sign in to comment.