How to set default value to a dataclass field? #58

Open
barenko opened this issue Oct 27, 2024 · 1 comment


barenko commented Oct 27, 2024

I want the Optional fields of the generated dataclasses to default to None, but I couldn't find how to do this in the documentation. Is it possible?

I have a table like:

create table t(
   id serial primary key, 
   name varchar not null, 
   description text null
)

Running sqlc generate, I get:

@dataclasses.dataclass()
class T:
   id: int
   name: str
   description: Optional[str]
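A minimal sketch of how the generated class behaves, assuming it is used directly as a plain dataclass (the sample values are made up). Optional[str] only describes the type and does not create a default, so description must always be passed, even when it is None:

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass()
class T:
    id: int
    name: str
    description: Optional[str]


# Works: description is passed explicitly, even though its value is None.
row = T(id=1, name="widget", description=None)

# Fails: the field has no default, so omitting it raises TypeError.
try:
    T(id=1, name="widget")
except TypeError as exc:
    print(f"TypeError: {exc}")
```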

What I actually want is for description to default to None, like:

@dataclasses.dataclass()
class T:
   id: int
   name: str
   description: Optional[str] = None
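A minimal sketch of the desired behavior, again assuming a plain dataclass (sample values and the Reordered class are hypothetical). With the default, description can be omitted. This only works because description is the last field: a dataclass rejects a field without a default that follows a field with one.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass()
class T:
    id: int
    name: str
    description: Optional[str] = None  # nullable column defaults to None


row = T(id=1, name="widget")  # description may now be omitted
print(row)

# Hypothetical class: a defaulted field followed by a non-defaulted one
# is rejected when the decorator runs, at class-definition time.
try:
    @dataclasses.dataclass()
    class Reordered:
        description: Optional[str] = None
        name: str
except TypeError as exc:
    print(f"TypeError: {exc}")
```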

I'm using sqlc v1.27.0 (via Docker) and generating Python code for PostgreSQL with this config:

version: '2'
plugins:
- name: py
  wasm:
    url: https://downloads.sqlc.dev/plugin/sqlc-gen-python_1.2.0.wasm
    sha256: a6c5d174c407007c3717eea36ff0882744346e6ba991f92f71d6ab2895204c0e
sql:
- schema: "schema.sql"
  queries: "queries.sql"
  engine: postgresql
  codegen:
  - out: ./model
    plugin: py
    options:
      package: model
      emit_sync_querier: true
      emit_async_querier: true
      emit_empty_slices: true

barenko commented Jan 12, 2025

I came up with a workaround so I can keep using sqlc with Python:

  • After running sqlc generate, run sed on the generated file to add the default values:
    sed -i '/^[[:space:]]*[a-zA-Z0-9_]*:[[:space:]]*Optional/ s/\(Optional\[\([^]]*\)\]\)/\1 = None/' models.py
  • A dataclass raises TypeError when a field without a default follows a field with one, so the fields then have to be reordered to put the ones with defaults last:
import re
import logging
import sys

logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')

def main(filename: str):
    with open(filename, "r") as file:
        contents = file.read().split("\n")

    dataclass_line = -1
    class_line = -1
    inner_class_indent = ''
    field_map = {}  # line index -> field line, for the current dataclass
    for idx, line in enumerate(contents):
        # The generated code puts this decorator above every dataclass
        if re.match(r"@dataclasses\.dataclass\(\)", line):
            dataclass_line = idx
            logging.debug(f"Found dataclass line: {dataclass_line}")
            continue
        if dataclass_line >= 0 and re.match(r"class ", line):
            class_line = idx
            logging.debug(f"Found class line: {class_line}")
            continue
        if class_line >= 0:
            # The first line of the class body sets the expected indentation
            if inner_class_indent == '':
                indent = re.match(r"^(\s+)", line)
                inner_class_indent = indent.group(1) if indent else ''
            if inner_class_indent and line.startswith(inner_class_indent) and line.strip():
                if re.match(r"\s+[a-z]+[a-z_0-9]*\s*:\s*", line):
                    logging.debug(f"Found field line: {line}")
                    field_map[idx] = line
            else:
                # A blank or unindented line ends the class body
                if field_map:
                    resort_fields(class_line, idx, contents, field_map)
                class_line = -1
                dataclass_line = -1
                inner_class_indent = ''
                field_map = {}

    # A dataclass at the very end of the file has no terminating line
    if class_line >= 0 and field_map:
        resort_fields(class_line, len(contents), contents, field_map)

    with open(filename, 'w') as file:
        file.write("\n".join(contents))

def resort_fields(class_line, end_line, contents, field_map):
    logging.info(f"Resorting fields from: {contents[class_line]}")

    # Stable sort on "has a default": fields without a default keep their
    # original order and come first; fields with one (e.g. "= None") go last.
    def has_default(kv):
        return " = " in kv[1]

    sorted_fields = sorted(field_map.items(), key=has_default)

    logging.debug("BEFORE:")
    for line in contents[class_line:end_line]:
        logging.debug(line)

    # Pair each original field position with the line that should occupy it
    original_idxs = list(field_map.keys())
    sorted_idxs = [s[0] for s in sorted_fields]

    logging.debug("FROM -> TO:")
    for o, d in zip(original_idxs, sorted_idxs):
        logging.debug(f"{field_map[o]} ({o}) -> {field_map[d]} ({d})")
        contents[o] = field_map[d]

    logging.info("AFTER:")
    for line in contents[class_line:end_line]:
        logging.info(line)


if __name__ == "__main__":
    if len(sys.argv) != 2:
        logging.error("Usage: resort_dataclass_optional_fields.py <file_to_modify.py>")
        sys.exit(1)

    main(sys.argv[1])
Run it on the generated file, for example:

python -W all scripts/resort_dataclass_optional_fields.py models.py

You need to do this for both the models file and the queries file.
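To illustrate what the two steps do together, here is a minimal in-memory sketch. It is not the script above: add_defaults_and_reorder is a hypothetical helper, the sample Item class is made up, and it assumes each dataclass is written as the @dataclasses.dataclass() line, then a class line, then a contiguous block of annotated fields. Unlike the sed command, its pattern only matches lines that end right after Optional[...], so running it twice does not append a second = None.

```python
import re

# Similar to the sed expression: "name: Optional[...]" gets " = None" appended.
OPTIONAL_FIELD = re.compile(r"^(\s+\w+:\s*Optional\[[^\]]*\])$")
# An annotated field line inside a class body, e.g. "    id: int".
FIELD = re.compile(r"^\s+\w+\s*:\s*\S")


def add_defaults_and_reorder(source: str) -> str:
    """Hypothetical helper: add None defaults, then move defaulted fields last."""
    # Step 1: add the defaults, like the sed command.
    lines = [OPTIONAL_FIELD.sub(r"\1 = None", line) for line in source.split("\n")]
    out: list[str] = []
    i = 0
    while i < len(lines):
        out.append(lines[i])
        is_dataclass_header = (
            lines[i].startswith("class ")
            and i > 0
            and lines[i - 1].strip() == "@dataclasses.dataclass()"
        )
        i += 1
        if not is_dataclass_header:
            continue
        # Collect the contiguous field block that follows the class line.
        fields = []
        while i < len(lines) and FIELD.match(lines[i]):
            fields.append(lines[i])
            i += 1
        # Step 2: stable partition, so both groups keep their original order.
        out.extend(f for f in fields if " = " not in f)
        out.extend(f for f in fields if " = " in f)
    return "\n".join(out)


generated = (
    "@dataclasses.dataclass()\n"
    "class Item:\n"
    "    id: int\n"
    "    note: Optional[str]\n"
    "    name: str\n"
)
print(add_defaults_and_reorder(generated))
```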
