diff --git a/internal/data/acls.go b/internal/data/acls.go index 482baa34..65180041 100644 --- a/internal/data/acls.go +++ b/internal/data/acls.go @@ -14,6 +14,7 @@ import ( "github.com/NHAS/wag/pkg/control" clientv3 "go.etcd.io/etcd/client/v3" "go.etcd.io/etcd/client/v3/clientv3util" + "golang.org/x/exp/maps" ) func SetAcl(effects string, policy acls.Acl, overwrite bool) error { @@ -76,17 +77,37 @@ func RemoveAcl(effects string) error { return err } +func insertMap(m map[string]bool, values ...string) { + for _, v := range values { + m[v] = true + } +} + func GetEffectiveAcl(username string) acls.Acl { - var resultingACLs acls.Acl - //Add the server address by default - resultingACLs.Allow = []string{config.Values.Wireguard.ServerAddress.String() + "/32"} + + var ( + // Do deduplication for multiple acls + allowSet = map[string]bool{} + mfaSet = map[string]bool{} + denySet = map[string]bool{} + ) + + insertMap(allowSet, config.Values.Wireguard.ServerAddress.String()+"/32") txn := etcd.Txn(context.Background()) txn.Then(clientv3.OpGet("wag-acls-*"), clientv3.OpGet("wag-acls-"+username), clientv3.OpGet(MembershipKey+"-"+username), clientv3.OpGet(dnsKey)) resp, err := txn.Commit() if err != nil { log.Println("failed to get policy data for user", username, "err:", err) - return acls.Acl{} + return acls.Acl{ + Allow: []string{config.Values.Wireguard.ServerAddress.String() + "/32"}, + } + } + + addAcls := func(acl acls.Acl) { + insertMap(allowSet, acl.Allow...) + insertMap(mfaSet, acl.Mfa...) + insertMap(denySet, acl.Deny...) } // the default policy contents @@ -95,8 +116,7 @@ func GetEffectiveAcl(username string) acls.Acl { err := json.Unmarshal(resp.Responses[0].GetResponseRange().Kvs[0].Value, &acl) if err == nil { - resultingACLs.Allow = append(resultingACLs.Allow, acl.Allow...) - resultingACLs.Mfa = append(resultingACLs.Mfa, acl.Mfa...) + addAcls(acl) } else { RaiseError(err, []byte("failed to unmarshal default acls policy")) log.Println("failed to unmarshal default acls policy: ", err) @@ -109,8 +129,7 @@ func GetEffectiveAcl(username string) acls.Acl { err := json.Unmarshal(resp.Responses[1].GetResponseRange().Kvs[0].Value, &acl) if err == nil { - resultingACLs.Allow = append(resultingACLs.Allow, acl.Allow...) - resultingACLs.Mfa = append(resultingACLs.Mfa, acl.Mfa...) + addAcls(acl) } else { log.Println("failed to unmarshal user specific acls: ", err) } @@ -148,9 +167,7 @@ func GetEffectiveAcl(username string) acls.Acl { log.Println("failed to unmarshal acl from response: ", err, string(r.Kvs[0].Value)) continue } - - resultingACLs.Allow = append(resultingACLs.Allow, acl.Allow...) - resultingACLs.Mfa = append(resultingACLs.Mfa, acl.Mfa...) + addAcls(acl) } } @@ -167,12 +184,18 @@ func GetEffectiveAcl(username string) acls.Acl { err = json.Unmarshal(resp.Responses[3].GetResponseRange().Kvs[0].Value, &dns) if err == nil { for _, server := range dns { - resultingACLs.Allow = append(resultingACLs.Allow, fmt.Sprintf("%s 53/any", server)) + insertMap(allowSet, fmt.Sprintf("%s 53/any", server)) } } else { log.Println("failed to unmarshal dns setting: ", err) } } + resultingACLs := acls.Acl{ + Allow: maps.Keys(allowSet), + Mfa: maps.Keys(mfaSet), + Deny: maps.Keys(denySet), + } + return resultingACLs } diff --git a/pkg/control/server/config.go b/pkg/control/server/config.go index f78aebf9..cdfc6fe1 100644 --- a/pkg/control/server/config.go +++ b/pkg/control/server/config.go @@ -51,7 +51,6 @@ func editPolicy(w http.ResponseWriter, r *http.Request) { if err := json.NewDecoder(r.Body).Decode(&polciyData); err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return - } if err := data.SetAcl(polciyData.Effects, acls.Acl{Mfa: polciyData.MfaRoutes, Allow: polciyData.PublicRoutes, Deny: polciyData.DenyRoutes}, true); err != nil {