Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for overriding pg schema #396

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build_and_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jobs:
POSTGRES_PASSWORD: testing
POSTGRES_DB: envio-dev
POSTGRES_USER: postgres
POSTGRES_SCHEMA: public
# Set health checks to wait until postgres has started
options: >-
--health-cmd pg_isready
Expand Down
45 changes: 45 additions & 0 deletions codegenerator/cli/npm/envio/src/bindings/Postgres.res
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type poolConfig = {
database?: string, // Name of database to connect to (default: '')
username?: string, // Username of database user (default: '')
password?: string, // Password of database user (default: '')
schema?: string, // Name of schema to connect to (default: 'public')
ssl?: sslOptions, // true, prefer, require, tls.connect options (default: false)
max?: int, // Max number of connections (default: 10)
maxLifetime?: option<int>, // Max lifetime in seconds (more info below) (default: null)
Expand All @@ -86,9 +87,53 @@ type poolConfig = {
fetchTypes?: bool, // Automatically fetches types on connect on initial connection. (default: true)
}

// Builds a libpq-style connection URI ("postgres://user:pass@host:port/db?...")
// from the pool config. Each user-supplied component is percent-encoded so that
// reserved URI characters (e.g. "@", ":" or "/" in a password) cannot corrupt
// the resulting URI.
let makeConnectionString = (config: poolConfig) => {
  let parts = ["postgres://"]
  let enc = Js.Global.encodeURIComponent

  // Credentials render as "user:pass@", "user@", or nothing. A password
  // without a username is ignored — a URI cannot carry a password alone.
  switch (config.username, config.password) {
  | (Some(username), Some(password)) =>
    parts->Js.Array2.push(`${enc(username)}:${enc(password)}@`)->ignore
  | (Some(username), None) => parts->Js.Array2.push(`${enc(username)}@`)->ignore
  | _ => ()
  }

  switch config.host {
  | Some(host) => parts->Js.Array2.push(host)->ignore
  | None => ()
  }

  switch config.port {
  | Some(port) => parts->Js.Array2.push(`:${port->Belt.Int.toString}`)->ignore
  | None => ()
  }

  switch config.database {
  | Some(database) => parts->Js.Array2.push(`/${enc(database)}`)->ignore
  | None => ()
  }

  // NOTE(review): "search_path" is not a parameter libpq recognises in a
  // connection URI — confirm the postgres.js driver actually applies it.
  // If it does not, use `options=-csearch_path%3D<schema>` or the driver's
  // `connection: {search_path: ...}` option instead.
  switch config.schema {
  | Some(schema) => parts->Js.Array2.push(`?search_path=${enc(schema)}`)->ignore
  | None => ()
  }

  parts->Js.Array2.joinWith("")
}

@module
external makeSql: (~config: poolConfig) => sql = "postgres"

@module
external makeSqlWithConnectionString: (string, poolConfig) => sql = "postgres"

// Creates a postgres.js client from the pool config, connecting via a
// connection-string URI so that schema selection can ride along in the URI.
let makeSql = (~config: poolConfig) => {
  let connectionString = makeConnectionString(config)

  // SECURITY: never log the connection string — it embeds the database
  // password in plain text, which would leak credentials into log output.
  makeSqlWithConnectionString(connectionString, config)
}

@send external beginSql: (sql, sql => array<promise<unit>>) => promise<unit> = "begin"

// TODO: can explore this approach (https://forum.rescript-lang.org/t/rfc-support-for-tagged-template-literals/3744)
Expand Down
4 changes: 2 additions & 2 deletions codegenerator/cli/npm/envio/src/db/EntityHistory.res
Original file line number Diff line number Diff line change
Expand Up @@ -243,8 +243,8 @@ let fromTable = (table: table, ~schema: S.t<'entity>): t<'entity> => {

let insertFnName = `"insert_${table.tableName}"`
let historyRowArg = "history_row"
let historyTablePath = `"public"."${historyTableName}"`
let originTablePath = `"public"."${originTableName}"`
let historyTablePath = `"${historyTableName}"`
let originTablePath = `"${originTableName}"`

let previousHistoryFieldsAreNullStr =
previousChangeFieldNames
Expand Down
2 changes: 1 addition & 1 deletion codegenerator/cli/npm/envio/src/db/Table.res
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ module PostgresInterop = {
table->getNonDefaultFieldNames->Array.map(fieldName => `"${fieldName}"`)
`(sql, rows) => {
return sql\`
INSERT INTO "public"."${table.tableName}"
INSERT INTO "${table.tableName}"
\${sql(rows, ${fieldNamesInQuotes->Js.Array2.joinWith(", ")})}
ON CONFLICT(${table->getPrimaryKeyFieldNames->Js.Array2.joinWith(", ")}) DO UPDATE
SET
Expand Down
35 changes: 29 additions & 6 deletions codegenerator/cli/src/persisted_state/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,33 @@ async fn get_pg_pool() -> Result<PgPool, sqlx::Error> {
let user = get_env_with_default("ENVIO_PG_USER", "postgres");
let password = get_env_with_default("ENVIO_POSTGRES_PASSWORD", "testing");
let database = get_env_with_default("ENVIO_PG_DATABASE", "envio-dev");

let connection_url = format!("postgres://{user}:{password}@{host}:{port}/{database}");

PgPoolOptions::new().connect(&connection_url).await
let schema = get_env_with_default("ENVIO_PG_SCHEMA", "public");

let connection_url = format!(
"postgresql://{}:{}@{}:{}/{}",
user, password, host, port, database
);

let pool = PgPoolOptions::new()
.max_connections(5)
.after_connect({
let schema = schema.clone();
move |conn, _| {
Box::pin({
let value = schema.clone();
async move {
sqlx::query(&format!("SET search_path TO {}", value))
.execute(conn)
.await?;
Ok(())
}
})
}
})
.connect(&connection_url)
.await?;

Ok(pool)
}

impl PersistedState {
Expand All @@ -27,7 +50,7 @@ impl PersistedState {
async fn upsert_to_db_with_pool(&self, pool: &PgPool) -> Result<PgQueryResult, sqlx::Error> {
sqlx::query(
r#"
INSERT INTO public.persisted_state (
INSERT INTO persisted_state (
id,
envio_version,
config_hash,
Expand Down Expand Up @@ -77,7 +100,7 @@ impl PersistedStateExists {
schema_hash,
handler_files_hash,
abi_files_hash
from public.persisted_state WHERE id = 1",
from persisted_state WHERE id = 1",
)
.fetch_optional(pool)
.await;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ services:
POSTGRES_PASSWORD: ${ENVIO_POSTGRES_PASSWORD:-testing}
POSTGRES_USER: ${ENVIO_PG_USER:-postgres}
POSTGRES_DB: ${ENVIO_PG_DATABASE:-envio-dev}
POSTGRES_SCHEMA: ${ENVIO_PG_SCHEMA:-public}
networks:
- my-proxy-net
graphql-engine:
Expand Down
1 change: 1 addition & 0 deletions codegenerator/cli/templates/static/codegen/src/Env.res
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ module Db = {
let user = envSafe->EnvSafe.get("ENVIO_PG_USER", S.string, ~devFallback="postgres")
let password = envSafe->EnvSafe.get("ENVIO_POSTGRES_PASSWORD", S.string, ~devFallback="testing")
let database = envSafe->EnvSafe.get("ENVIO_PG_DATABASE", S.string, ~devFallback="envio-dev")
let schema = envSafe->EnvSafe.get("ENVIO_PG_SCHEMA", S.string, ~fallback="public")
let ssl = envSafe->EnvSafe.get(
"ENVIO_PG_SSL_MODE",
Postgres.sslOptionsSchema,
Expand Down
1 change: 1 addition & 0 deletions codegenerator/cli/templates/static/codegen/src/db/Db.res
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ let config: Postgres.poolConfig = {
username: Env.Db.user,
password: Env.Db.password,
database: Env.Db.database,
schema: Env.Db.schema,
ssl: Env.Db.ssl,
// TODO: think how we want to pipe these logs to pino.
onnotice: ?(Env.userLogLevel == #warn || Env.userLogLevel == #error ? None : Some(_str => ())),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ module General = {
type existsRes = {exists: bool}

let hasRows = async (sql, ~table: Table.table) => {
let query = `SELECT EXISTS(SELECT 1 FROM public."${table.tableName}");`
let query = `SELECT EXISTS(SELECT 1 FROM "${table.tableName}");`
switch await sql->Postgres.unsafe(query) {
| [{exists}] => exists
| _ => Js.Exn.raiseError("Unexpected result from hasRows query: " ++ query)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ const batchSetItemsInTableCore = (table, sql, rowDataArray) => {
);

return sql`
INSERT INTO "public".${sql(table.tableName)}
INSERT INTO ${sql(table.tableName)}
${sql(rowDataArray, ...fieldNames)}
ON CONFLICT(${sql`${commaSeparateDynamicMapQuery(
sql,
Expand All @@ -56,7 +56,7 @@ module.exports.batchDeleteItemsInTable = (table, sql, pkArray) => {
if (primaryKeyFieldNames.length === 1) {
return sql`
DELETE
FROM "public".${sql(table.tableName)}
FROM ${sql(table.tableName)}
WHERE ${sql(primaryKeyFieldNames[0])} IN ${sql(pkArray)};
`;
} else {
Expand All @@ -71,7 +71,7 @@ module.exports.batchReadItemsInTable = (table, sql, pkArray) => {
if (primaryKeyFieldNames.length === 1) {
return sql`
SELECT *
FROM "public".${sql(table.tableName)}
FROM ${sql(table.tableName)}
WHERE ${sql(primaryKeyFieldNames[0])} IN ${sql(pkArray)};
`;
} else {
Expand All @@ -83,19 +83,19 @@ module.exports.batchReadItemsInTable = (table, sql, pkArray) => {
module.exports.whereEqQuery = (table, sql, fieldName, value) => {
return sql`
SELECT *
FROM "public".${sql(table.tableName)}
FROM ${sql(table.tableName)}
WHERE ${sql(fieldName)} = ${value};
`;
};

module.exports.readLatestSyncedEventOnChainId = (sql, chainId) => sql`
SELECT *
FROM public.event_sync_state
FROM event_sync_state
WHERE chain_id = ${chainId}`;

module.exports.batchSetEventSyncState = (sql, entityDataArray) => {
return sql`
INSERT INTO public.event_sync_state
INSERT INTO event_sync_state
${sql(
entityDataArray,
"chain_id",
Expand All @@ -116,12 +116,12 @@ module.exports.batchSetEventSyncState = (sql, entityDataArray) => {

module.exports.readLatestChainMetadataState = (sql, chainId) => sql`
SELECT *
FROM public.chain_metadata
FROM chain_metadata
WHERE chain_id = ${chainId}`;

module.exports.batchSetChainMetadata = (sql, entityDataArray) => {
return sql`
INSERT INTO public.chain_metadata
INSERT INTO chain_metadata
${sql(
entityDataArray,
"chain_id",
Expand Down Expand Up @@ -154,7 +154,7 @@ module.exports.batchSetChainMetadata = (sql, entityDataArray) => {

const batchSetRawEventsCore = (sql, entityDataArray) => {
return sql`
INSERT INTO "public"."raw_events"
INSERT INTO "raw_events"
${sql(
entityDataArray,
"chain_id",
Expand All @@ -178,13 +178,13 @@ module.exports.batchSetRawEvents = (sql, entityDataArray) => {

module.exports.batchDeleteRawEvents = (sql, entityIdArray) => sql`
DELETE
FROM "public"."raw_events"
FROM "raw_events"
WHERE (chain_id, event_id) IN ${sql(entityIdArray)};`;
// end db operations for raw_events

const batchSetEndOfBlockRangeScannedDataCore = (sql, rowDataArray) => {
return sql`
INSERT INTO "public"."end_of_block_range_scanned_data"
INSERT INTO "end_of_block_range_scanned_data"
${sql(
rowDataArray,
"chain_id",
Expand All @@ -210,7 +210,7 @@ module.exports.batchSetEndOfBlockRangeScannedData = (sql, rowDataArray) => {

module.exports.readEndOfBlockRangeScannedDataForChain = (sql, chainId) => {
return sql`
SELECT * FROM "public"."end_of_block_range_scanned_data"
SELECT * FROM "end_of_block_range_scanned_data"
WHERE
chain_id = ${chainId}
ORDER BY block_number ASC;`;
Expand All @@ -224,7 +224,7 @@ module.exports.deleteStaleEndOfBlockRangeScannedDataForChain = (
) => {
return sql`
DELETE
FROM "public"."end_of_block_range_scanned_data"
FROM "end_of_block_range_scanned_data"
WHERE chain_id = ${chainId}
AND block_number < ${blockNumberThreshold}
AND block_timestamp < ${blockTimestampThreshold}
Expand All @@ -237,7 +237,7 @@ module.exports.readDynamicContractsOnChainIdAtOrBeforeBlockNumber = (
blockNumber
) => sql`
SELECT *
FROM "public"."dynamic_contract_registry"
FROM "dynamic_contract_registry"
WHERE registering_event_block_number <= ${blockNumber}
AND chain_id = ${chainId};`;

Expand All @@ -248,7 +248,7 @@ module.exports.readDynamicContractsOnChainIdMatchingEvents = (
) => {
return sql`
SELECT *
FROM "public"."dynamic_contract_registry"
FROM "dynamic_contract_registry"
WHERE chain_id = ${chainId}
AND (registering_event_contract_name, registering_event_name, registering_event_src_address) IN ${sql(
preRegisterEvents.map((item) => sql(item))
Expand All @@ -272,7 +272,7 @@ module.exports.getFirstChangeSerial_UnorderedMultichain = (
SELECT
MIN(serial) AS first_change_serial
FROM
public.${sql(makeHistoryTableName(entityName))}
${sql(makeHistoryTableName(entityName))}
WHERE
entity_history_chain_id = ${reorgChainId}
AND entity_history_block_number > ${safeBlockNumber}
Expand All @@ -292,7 +292,7 @@ module.exports.getFirstChangeSerial_OrderedMultichain = (
SELECT
MIN(serial) AS first_change_serial
FROM
public.${sql(makeHistoryTableName(entityName))}
${sql(makeHistoryTableName(entityName))}
WHERE
entity_history_block_timestamp > ${safeBlockTimestamp}
OR
Expand All @@ -317,7 +317,7 @@ module.exports.getFirstChangeEntityHistoryPerChain = (
SELECT DISTINCT
ON (entity_history_chain_id) *
FROM
public.${sql(makeHistoryTableName(entityName))}
${sql(makeHistoryTableName(entityName))}
WHERE
serial >= (
SELECT
Expand All @@ -344,7 +344,7 @@ module.exports.deleteRolledBackEntityHistory = (
)
-- Step 2: Delete all rows that have a serial >= the first change serial
DELETE FROM
public.${sql(makeHistoryTableName(entityName))}
${sql(makeHistoryTableName(entityName))}
WHERE
serial >= (
SELECT
Expand All @@ -371,7 +371,7 @@ module.exports.pruneStaleEntityHistory = (
SELECT
MIN(serial) AS first_change_serial
FROM
public.${sql(tableName)}
${sql(tableName)}
WHERE
${Utils.$$Array.interleave(
safeChainIdAndBlockNumberArray.map(
Expand All @@ -385,7 +385,7 @@ module.exports.pruneStaleEntityHistory = (
SELECT DISTINCT
ON (id) *
FROM
public.${sql(tableName)}
${sql(tableName)}
WHERE
serial >= (SELECT first_change_serial FROM first_change)
ORDER BY
Expand All @@ -400,7 +400,7 @@ module.exports.pruneStaleEntityHistory = (
prev.id,
prev.serial
FROM
public.${sql(tableName)} prev
${sql(tableName)} prev
INNER JOIN
items_in_reorg_threshold r
ON
Expand All @@ -415,7 +415,7 @@ module.exports.pruneStaleEntityHistory = (
: sql``
}
DELETE FROM
public.${sql(tableName)} eh
${sql(tableName)} eh
WHERE
-- Delete all entity history of entities that are not in the reorg threshold
eh.id NOT IN (SELECT id FROM items_in_reorg_threshold)
Expand All @@ -442,7 +442,7 @@ module.exports.getRollbackDiff = (sql, entityName, getFirstChangeSerial) => sql`
SELECT DISTINCT
ON (id) after.*
FROM
public.${sql(makeHistoryTableName(entityName))} after
${sql(makeHistoryTableName(entityName))} after
WHERE
after.serial >= (
SELECT
Expand All @@ -469,7 +469,7 @@ module.exports.getRollbackDiff = (sql, entityName, getFirstChangeSerial) => sql`
COALESCE(before.entity_history_log_index, 0) AS entity_history_log_index
FROM
-- Use a RIGHT JOIN, to ensure that nulls get returned if there is no "before" row
public.${sql(makeHistoryTableName(entityName))} before
${sql(makeHistoryTableName(entityName))} before
RIGHT JOIN rollback_ids after ON before.id = after.id
AND before.entity_history_block_timestamp = after.previous_entity_history_block_timestamp
AND before.entity_history_chain_id = after.previous_entity_history_chain_id
Expand Down
Loading