Skip to content

Commit

Permalink
connection.rs: Don't allow queries with values
Browse files Browse the repository at this point in the history
After serialization refactor it will be impossible to perform unprepared
query with values, because serializing values will require knowing
column types.

In order to make refactor easier and better split responsibility, this
commit removes `values` arguments from `Connection` methods, so that
it is callers responsibility to prepare the query if necessary.
  • Loading branch information
Lorak-mmk committed Nov 23, 2023
1 parent 33f0a0a commit 550ce4f
Show file tree
Hide file tree
Showing 4 changed files with 159 additions and 90 deletions.
115 changes: 63 additions & 52 deletions scylla/src/transport/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use scylla_cql::errors::TranslationError;
use scylla_cql::frame::request::options::Options;
use scylla_cql::frame::response::Error;
use scylla_cql::frame::types::SerialConsistency;
use scylla_cql::frame::value::SerializedValues;
use socket2::{SockRef, TcpKeepalive};
use tokio::io::{split, AsyncRead, AsyncWrite, AsyncWriteExt, BufReader, BufWriter};
use tokio::net::{TcpSocket, TcpStream};
Expand Down Expand Up @@ -596,7 +597,6 @@ impl Connection {
pub(crate) async fn query_single_page(
&self,
query: impl Into<Query>,
values: impl ValueList,
) -> Result<QueryResult, QueryError> {
let query: Query = query.into();

Expand All @@ -606,38 +606,30 @@ impl Connection {
.determine_consistency(self.config.default_consistency);
let serial_consistency = query.config.serial_consistency;

self.query_single_page_with_consistency(
query,
&values,
consistency,
serial_consistency.flatten(),
)
.await
self.query_single_page_with_consistency(query, consistency, serial_consistency.flatten())
.await
}

pub(crate) async fn query_single_page_with_consistency(
&self,
query: impl Into<Query>,
values: impl ValueList,
consistency: Consistency,
serial_consistency: Option<SerialConsistency>,
) -> Result<QueryResult, QueryError> {
let query: Query = query.into();
self.query_with_consistency(&query, &values, consistency, serial_consistency, None)
self.query_with_consistency(&query, consistency, serial_consistency, None)
.await?
.into_query_result()
}

pub(crate) async fn query(
&self,
query: &Query,
values: impl ValueList,
paging_state: Option<Bytes>,
) -> Result<QueryResponse, QueryError> {
// This method is used only for driver internal queries, so no need to consult execution profile here.
self.query_with_consistency(
query,
values,
query
.config
.determine_consistency(self.config.default_consistency),
Expand All @@ -650,33 +642,16 @@ impl Connection {
pub(crate) async fn query_with_consistency(
&self,
query: &Query,
values: impl ValueList,
consistency: Consistency,
serial_consistency: Option<SerialConsistency>,
paging_state: Option<Bytes>,
) -> Result<QueryResponse, QueryError> {
let serialized_values = values.serialized()?;

let values_size = serialized_values.size();
if values_size != 0 {
let prepared = self.prepare(query).await?;
return self
.execute_with_consistency(
&prepared,
values,
consistency,
serial_consistency,
paging_state,
)
.await;
}

let query_frame = query::Query {
contents: Cow::Borrowed(&query.contents),
parameters: query::QueryParameters {
consistency,
serial_consistency,
values: serialized_values,
values: Cow::Borrowed(SerializedValues::EMPTY),
page_size: query.get_page_size(),
paging_state,
timestamp: query.get_timestamp(),
Expand All @@ -687,6 +662,26 @@ impl Connection {
.await
}

#[allow(dead_code)]
pub(crate) async fn execute(
&self,
prepared: PreparedStatement,
values: impl ValueList,
paging_state: Option<Bytes>,
) -> Result<QueryResponse, QueryError> {
// This method is used only for driver internal queries, so no need to consult execution profile here.
self.execute_with_consistency(
&prepared,
values,
prepared
.config
.determine_consistency(self.config.default_consistency),
prepared.config.serial_consistency.flatten(),
paging_state,
)
.await
}

pub(crate) async fn execute_with_consistency(
&self,
prepared_statement: &PreparedStatement,
Expand Down Expand Up @@ -734,19 +729,33 @@ impl Connection {
pub(crate) async fn query_iter(
self: Arc<Self>,
query: Query,
values: impl ValueList,
) -> Result<RowIterator, QueryError> {
let serialized_values = values.serialized()?.into_owned();

let consistency = query
.config
.determine_consistency(self.config.default_consistency);
let serial_consistency = query.config.serial_consistency.flatten();

RowIterator::new_for_connection_query_iter(
query,
RowIterator::new_for_connection_query_iter(query, self, consistency, serial_consistency)
.await
}

/// Executes a prepared statements and fetches its results over multiple pages, using
/// the asynchronous iterator interface.
pub(crate) async fn execute_iter(
self: Arc<Self>,
prepared_statement: PreparedStatement,
values: impl ValueList,
) -> Result<RowIterator, QueryError> {
let consistency = prepared_statement
.config
.determine_consistency(self.config.default_consistency);
let serial_consistency = prepared_statement.config.serial_consistency.flatten();
let serialized = values.serialized()?.into_owned();

RowIterator::new_for_connection_execute_iter(
prepared_statement,
serialized,
self,
serialized_values,
consistency,
serial_consistency,
)
Expand Down Expand Up @@ -885,7 +894,7 @@ impl Connection {
false => format!("USE {}", keyspace_name.as_str()).into(),
};

let query_response = self.query(&query, (), None).await?;
let query_response = self.query(&query, None).await?;

match query_response.response {
Response::Result(result::Result::SetKeyspace(set_keyspace)) => {
Expand Down Expand Up @@ -929,7 +938,7 @@ impl Connection {

pub(crate) async fn fetch_schema_version(&self) -> Result<Uuid, QueryError> {
let (version_id,): (Uuid,) = self
.query_single_page(LOCAL_VERSION, &[])
.query_single_page(LOCAL_VERSION)
.await?
.rows
.ok_or(QueryError::ProtocolError("Version query returned not rows"))?
Expand Down Expand Up @@ -1833,7 +1842,6 @@ mod tests {
use super::ConnectionConfig;
use crate::query::Query;
use crate::transport::connection::open_connection;
use crate::transport::connection::QueryResponse;
use crate::transport::node::ResolvedContactPoint;
use crate::transport::topology::UntranslatedEndpoint;
use crate::utils::test_utils::unique_keyspace_name;
Expand Down Expand Up @@ -1914,7 +1922,7 @@ mod tests {
let select_query = Query::new("SELECT p FROM connection_query_iter_tab").with_page_size(7);
let empty_res = connection
.clone()
.query_iter(select_query.clone(), &[])
.query_iter(select_query.clone())
.await
.unwrap()
.try_collect::<Vec<_>>()
Expand All @@ -1927,15 +1935,18 @@ mod tests {
let mut insert_futures = Vec::new();
let insert_query =
Query::new("INSERT INTO connection_query_iter_tab (p) VALUES (?)").with_page_size(7);
let prepared = connection.prepare(&insert_query).await.unwrap();
for v in &values {
insert_futures.push(connection.query_single_page(insert_query.clone(), (v,)));
let prepared_clone = prepared.clone();
let fut = async { connection.execute(prepared_clone, (*v,), None).await };
insert_futures.push(fut);
}

futures::future::try_join_all(insert_futures).await.unwrap();

let mut results: Vec<i32> = connection
.clone()
.query_iter(select_query.clone(), &[])
.query_iter(select_query.clone())
.await
.unwrap()
.into_typed::<(i32,)>()
Expand All @@ -1947,7 +1958,9 @@ mod tests {

// 3. INSERT query_iter should work and not return any rows.
let insert_res1 = connection
.query_iter(insert_query, (0,))
.query_iter(Query::new(
"INSERT INTO connection_query_iter_tab (p) VALUES (0)",
))
.await
.unwrap()
.try_collect::<Vec<_>>()
Expand Down Expand Up @@ -2007,10 +2020,7 @@ mod tests {
.await
.unwrap();

connection
.query(&"TRUNCATE t".into(), (), None)
.await
.unwrap();
connection.query(&"TRUNCATE t".into(), None).await.unwrap();

let mut futs = Vec::new();

Expand All @@ -2025,8 +2035,9 @@ mod tests {
let q = Query::new("INSERT INTO t (p, v) VALUES (?, ?)");
let conn = conn.clone();
async move {
let response: QueryResponse = conn
.query(&q, (j, vec![j as u8; j as usize]), None)
let prepared = conn.prepare(&q).await.unwrap();
let response = conn
.execute(prepared.clone(), (j, vec![j as u8; j as usize]), None)
.await
.unwrap();
// QueryResponse might contain an error - make sure that there were no errors
Expand All @@ -2045,7 +2056,7 @@ mod tests {
// Check that everything was written properly
let range_end = arithmetic_sequence_sum(NUM_BATCHES);
let mut results = connection
.query(&"SELECT p, v FROM t".into(), (), None)
.query(&"SELECT p, v FROM t".into(), None)
.await
.unwrap()
.into_query_result()
Expand Down Expand Up @@ -2198,7 +2209,7 @@ mod tests {
// As everything is normal, these queries should succeed.
for _ in 0..3 {
tokio::time::sleep(Duration::from_millis(500)).await;
conn.query_single_page("SELECT host_id FROM system.local", ())
conn.query_single_page("SELECT host_id FROM system.local")
.await
.unwrap();
}
Expand All @@ -2218,7 +2229,7 @@ mod tests {

// As the router is invalidated, all further queries should immediately
// return error.
conn.query_single_page("SELECT host_id FROM system.local", ())
conn.query_single_page("SELECT host_id FROM system.local")
.await
.unwrap_err();

Expand Down
58 changes: 42 additions & 16 deletions scylla/src/transport/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,6 @@ impl RowIterator {

pub(crate) async fn new_for_query(
mut query: Query,
values: SerializedValues,
execution_profile: Arc<ExecutionProfileInner>,
cluster_data: Arc<ClusterData>,
metrics: Arc<Metrics>,
Expand Down Expand Up @@ -162,29 +161,27 @@ impl RowIterator {
let parent_span = tracing::Span::current();
let worker_task = async move {
let query_ref = &query;
let values_ref = &values;

let choose_connection = |node: Arc<Node>| async move { node.random_connection().await };

let page_query = |connection: Arc<Connection>,
consistency: Consistency,
paging_state: Option<Bytes>| async move {
connection
.query_with_consistency(
query_ref,
values_ref,
consistency,
serial_consistency,
paging_state,
)
.await
paging_state: Option<Bytes>| {
async move {
connection
.query_with_consistency(
query_ref,
consistency,
serial_consistency,
paging_state,
)
.await
}
};

let query_ref = &query;
let serialized_values_size = values.size();

let span_creator =
move || RequestSpan::new_query(&query_ref.contents, serialized_values_size);
let span_creator = move || RequestSpan::new_query(&query_ref.contents, 0);

let worker = RowIteratorWorker {
sender: sender.into(),
Expand Down Expand Up @@ -337,7 +334,6 @@ impl RowIterator {
pub(crate) async fn new_for_connection_query_iter(
mut query: Query,
connection: Arc<Connection>,
values: SerializedValues,
consistency: Consistency,
serial_consistency: Option<SerialConsistency>,
) -> Result<RowIterator, QueryError> {
Expand All @@ -352,6 +348,36 @@ impl RowIterator {
fetcher: |paging_state| {
connection.query_with_consistency(
&query,
consistency,
serial_consistency,
paging_state,
)
},
};
worker.work().await
};

Self::new_from_worker_future(worker_task, receiver).await
}

pub(crate) async fn new_for_connection_execute_iter(
mut prepared: PreparedStatement,
values: SerializedValues,
connection: Arc<Connection>,
consistency: Consistency,
serial_consistency: Option<SerialConsistency>,
) -> Result<RowIterator, QueryError> {
if prepared.get_page_size().is_none() {
prepared.set_page_size(DEFAULT_ITER_PAGE_SIZE);
}
let (sender, receiver) = mpsc::channel::<Result<ReceivedPage, QueryError>>(1);

let worker_task = async move {
let worker = SingleConnectionRowIteratorWorker {
sender: sender.into(),
fetcher: |paging_state| {
connection.execute_with_consistency(
&prepared,
&values,
consistency,
serial_consistency,
Expand Down
Loading

0 comments on commit 550ce4f

Please sign in to comment.