Skip to content

Commit

Permalink
feat(sql storage): [BREAKING CHANGES] : Fix "prune signals after sent…
Browse files Browse the repository at this point in the history
…" feature and modify database structure (#12)

* feat(sql): Handle signal prune and delete on cascade

* [BREAKING CHANGE] database: Use signal_id foreign key in source table

[BREAKING CHANGE] database: Use delete on cascade when possible

Fix: fix prune signal after sent feature

* style(*): Pass through black linter
  • Loading branch information
julienloizelet authored Feb 2, 2024
1 parent 8750981 commit be6779a
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 21 deletions.
18 changes: 18 additions & 0 deletions examples/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,24 @@
machine_id=generate_machine_id_from_key("myMachineKeyIdentifier"),
context=[{"key": "scenario-version", "value": "1.0.0"}],
message="test message to see where it is written",
decisions=[
{
"origin": "crowdsec",
"duration": "1h",
"scenario": "crowdsec/ssh-bf",
"scope": "ip",
"type": "ban",
"value": "81.81.81.81",
},
{
"origin": "pysdk",
"duration": "2h",
"scenario": "crowdsec/ssh-bf",
"scope": "ip",
"type": "ban",
"value": "81.81.81.81",
},
],
)
]

Expand Down
60 changes: 39 additions & 21 deletions src/cscapi/sql_storage.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import asdict
from typing import List
from typing import List, Optional

from dacite import from_dict
from sqlalchemy import (
Expand All @@ -12,6 +12,7 @@
create_engine,
delete,
update,
event,
)
from sqlalchemy.orm import (
DeclarativeBase,
Expand All @@ -21,9 +22,23 @@
sessionmaker,
)

from sqlalchemy.engine import Engine
from cscapi import storage


"""
By default, foreign key constraints are disabled in SQLite.
@see https://docs.sqlalchemy.org/en/20/dialects/sqlite.html#foreign-key-support
"""


@event.listens_for(Engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()


class Base(DeclarativeBase):
def to_dict(self):
return {c.name: getattr(self, c.name) for c in self.__table__.columns}
Expand All @@ -33,7 +48,7 @@ class MachineDBModel(Base):
__tablename__ = "machine_models"

id = Column(Integer, primary_key=True, autoincrement=True)
machine_id = Column(TEXT)
machine_id = Column(TEXT, unique=True)
token = Column(TEXT)
password = Column(TEXT)
scenarios = Column(TEXT)
Expand All @@ -54,7 +69,7 @@ class DecisionDBModel(Base):
type = Column(TEXT)
value = Column(TEXT)
signal_id: Mapped[int] = mapped_column(
"signal_id", ForeignKey("signal_models.alert_id")
"signal_id", ForeignKey("signal_models.alert_id", ondelete="CASCADE")
)


Expand All @@ -71,6 +86,9 @@ class SourceDBModel(Base):
value = Column(TEXT)
as_name = Column(TEXT)
longitude = Column(Float)
signal_id = Column(
Integer, ForeignKey("signal_models.alert_id", ondelete="CASCADE")
)


class ContextDBModel(Base):
Expand All @@ -80,7 +98,7 @@ class ContextDBModel(Base):
value = Column(TEXT)
key = Column(TEXT)
signal_id: Mapped[int] = mapped_column(
"signal_id", ForeignKey("signal_models.alert_id")
"signal_id", ForeignKey("signal_models.alert_id", ondelete="CASCADE")
)


Expand All @@ -100,8 +118,6 @@ class SignalDBModel(Base):
stop_at = Column(TEXT, nullable=True)
sent = Column(Boolean, default=False)

source_id = Column(Integer, ForeignKey("source_models.id"), nullable=True)

context: Mapped[List["ContextDBModel"]] = relationship(
"ContextDBModel", backref="signal"
)
Expand Down Expand Up @@ -133,29 +149,29 @@ def get_all_signals(self) -> List[storage.SignalModel]:
for res in self.session.query(SignalDBModel).all()
]

def get_machine_by_id(self, machine_id: str) -> storage.MachineModel:
exisiting = (
def get_machine_by_id(self, machine_id: str) -> Optional[storage.MachineModel]:
existing = (
self.session.query(MachineDBModel)
.filter(MachineDBModel.machine_id == machine_id)
.first()
)
if not exisiting:
return
if not existing:
return None
return storage.MachineModel(
machine_id=exisiting.machine_id,
token=exisiting.token,
password=exisiting.password,
scenarios=exisiting.scenarios,
is_failing=exisiting.is_failing,
machine_id=existing.machine_id,
token=existing.token,
password=existing.password,
scenarios=existing.scenarios,
is_failing=existing.is_failing,
)

def update_or_create_machine(self, machine: storage.MachineModel) -> bool:
exisiting = (
existing = (
self.session.query(MachineDBModel)
.filter(MachineDBModel.machine_id == machine.machine_id)
.all()
)
if not exisiting:
if not existing:
self.session.add(MachineDBModel(**asdict(machine)))
self.session.commit()
return True
Expand Down Expand Up @@ -193,19 +209,19 @@ def update_or_create_signal(self, signal: storage.SignalModel) -> bool:
for dec in signal.decisions
]

exisiting = (
existing = (
self.session.query(SignalDBModel)
.filter(SignalDBModel.alert_id == signal.alert_id)
.first()
)

if not exisiting:
if not existing:
self.session.add(to_insert)
self.session.commit()
return True

for c in to_insert.__table__.columns:
setattr(exisiting, c.name, getattr(to_insert, c.name))
setattr(existing, c.name, getattr(to_insert, c.name))

self.session.commit()
return False
Expand All @@ -215,9 +231,11 @@ def delete_signals(self, signals: List[storage.SignalModel]):
SignalDBModel.alert_id.in_((signal.alert_id for signal in signals))
)
self.session.execute(stmt)
self.session.commit()

def delete_machines(self, machines: List[storage.MachineModel]):
stmt = delete(MachineDBModel).where(
MachineDBModel.machine_id in ([machine.machine_id for machine in machines])
MachineDBModel.machine_id.in_((machine.machine_id for machine in machines))
)
self.session.execute(stmt)
self.session.commit()

0 comments on commit be6779a

Please sign in to comment.