From c2e6b5e3be705e675c8183a9d43070a67bbf65f3 Mon Sep 17 00:00:00 2001 From: Jim Myers Date: Fri, 21 Apr 2023 17:24:10 -0400 Subject: [PATCH] Add filtering, sorting, getsqlstatement --- python/shroomdk/__init__.py | 1 + python/shroomdk/errors/__init__.py | 1 + python/shroomdk/errors/not_found_error.py | 13 ++++ .../compass_query_integration.py | 78 +++++++++++++++++-- python/shroomdk/models/__init__.py | 1 + python/shroomdk/models/compass/__init__.py | 1 + .../models/compass/get_query_run_results.py | 12 +++ .../models/compass/get_sql_statement.py | 26 +++++++ python/shroomdk/models/query_result_set.py | 1 + python/shroomdk/rpc.py | 21 ++++- python/shroomdk/shroomdk.py | 33 +++++++- .../test_query_compass_integration.py | 8 +- 12 files changed, 185 insertions(+), 11 deletions(-) create mode 100644 python/shroomdk/errors/not_found_error.py create mode 100644 python/shroomdk/models/compass/get_sql_statement.py diff --git a/python/shroomdk/__init__.py b/python/shroomdk/__init__.py index 10f7b11..0eab007 100644 --- a/python/shroomdk/__init__.py +++ b/python/shroomdk/__init__.py @@ -1,2 +1,3 @@ +from .models import * from .rpc import RPC # noqa: F401 from .shroomdk import ShroomDK # noqa: F401 diff --git a/python/shroomdk/errors/__init__.py b/python/shroomdk/errors/__init__.py index 03ebd9d..fbd6d5a 100644 --- a/python/shroomdk/errors/__init__.py +++ b/python/shroomdk/errors/__init__.py @@ -1,3 +1,4 @@ +from .not_found_error import NotFoundError # noqa: F401 from .query_run_errors import ( # noqa: F401 QueryRunCancelledError, QueryRunExecutionError, diff --git a/python/shroomdk/errors/not_found_error.py b/python/shroomdk/errors/not_found_error.py new file mode 100644 index 0000000..8267fe9 --- /dev/null +++ b/python/shroomdk/errors/not_found_error.py @@ -0,0 +1,13 @@ +from typing import Union + +from .base_error import BaseError + + +class NotFoundError(BaseError): + """ + When an object is not found on the server. + """ + + def __init__(self, message: Union[str, None]): + self.message = message + super().__init__(self.message) diff --git a/python/shroomdk/integrations/query_integration/compass_query_integration.py b/python/shroomdk/integrations/query_integration/compass_query_integration.py index 53d0f85..e5af70b 100644 --- a/python/shroomdk/integrations/query_integration/compass_query_integration.py +++ b/python/shroomdk/integrations/query_integration/compass_query_integration.py @@ -1,7 +1,8 @@ import json -from typing import Union +from typing import List, Optional, Union from shroomdk.errors import ( + NotFoundError, QueryRunCancelledError, QueryRunExecutionError, QueryRunTimeoutError, @@ -18,13 +19,17 @@ from shroomdk.models import ( from shroomdk.models.compass.core.page import Page from shroomdk.models.compass.core.query_run import QueryRun from shroomdk.models.compass.core.result_format import ResultFormat +from shroomdk.models.compass.core.sql_statement import SqlStatement from shroomdk.models.compass.core.tags import Tags from shroomdk.models.compass.create_query_run import CreateQueryRunRpcParams from shroomdk.models.compass.get_query_run import GetQueryRunRpcRequestParams from shroomdk.models.compass.get_query_run_results import ( + Filter, GetQueryRunResultsRpcParams, GetQueryRunResultsRpcResult, + SortBy, ) +from shroomdk.models.compass.get_sql_statement import GetSqlStatementParams from shroomdk.rpc import RPC from shroomdk.utils.sleep import get_elapsed_linear_seconds, linear_backoff @@ -66,7 +71,7 @@ class CompassQueryIntegration(object): if not created_query.result or not created_query.result.queryRun: raise SDKError("expected `query_run` from server but got `None`") - query_run = self._get_query_run( + query_run = self._get_query_run_loop( created_query.result.queryRun.id, page_number=query.page_number, page_size=query.page_size, @@ -87,9 +92,70 @@ class CompassQueryIntegration(object): query_result=query_result, ).build() + def get_sql_statement(self, sql_statement_id: str) -> SqlStatement: + response = self.rpc.get_sql_statement( + GetSqlStatementParams(**{"sqlStatementId": sql_statement_id}) + ) + + if response.error or not response.result: + raise NotFoundError(f"SQLStatement<{sql_statement_id}> not found") + + return response.result.sqlStatement + + def get_query_run(self, query_run_id: str) -> QueryRun: + response = self.rpc.get_query_run( + GetQueryRunRpcRequestParams(queryRunId=query_run_id) + ) + + if response.error or not response.result: + raise NotFoundError(f"QueryRun<{query_run_id}> not found") + + return response.result.queryRun + + def get_query_results( + self, + query_run_id: str, + page_number: int = 1, + page_size: int = 100000, + filters: Optional[Union[List[Filter], None]] = [], + sort_by: Optional[Union[List[SortBy], None]] = [], + ) -> QueryResultSet: + query_result = self._get_query_results( + query_run_id, + page_number=page_number if page_number else 1, + page_size=page_size if page_size else 10000, + filters=filters, + sort_by=sort_by, + ) + + query_run = ( + query_result.redirectedToQueryRun + if query_result.redirectedToQueryRun + else query_result.originalQueryRun + ) + return QueryResultSetBuilder( + query_run=query_run, + query_result=query_result, + ).build() + def _get_query_results( - self, query_run_id: str, page_number: int = 1, page_size: int = 100000 + self, + query_run_id: str, + page_number: int = 1, + page_size: int = 100000, + filters: Optional[Union[List[Filter], None]] = [], + sort_by: Optional[Union[List[SortBy], None]] = [], ) -> GetQueryRunResultsRpcResult: + # f2 = [] + # if filters: + # for f in filters: + # d = f.dict() + # d2 = {} + # for k, v in d.items(): + # if v is not None: + # d2[k] = v + # f2.append(Filter(**d2)) + query_results_resp = self.rpc.get_query_result( GetQueryRunResultsRpcParams( queryRunId=query_run_id, @@ -98,6 +164,8 @@ class CompassQueryIntegration(object): number=page_number, size=page_size, ), + filters=filters, + sortBy=sort_by, ) ) @@ -121,7 +189,7 @@ class CompassQueryIntegration(object): ) return Query(**query_default_dict) - def _get_query_run( + def _get_query_run_loop( self, query_run_id: str, page_number: int = 1, @@ -184,7 +252,7 @@ class CompassQueryIntegration(object): raise QueryRunTimeoutError(elapsed_seconds) - return self._get_query_run( + return self._get_query_run_loop( query_run_id, page_number, page_size, diff --git a/python/shroomdk/models/__init__.py b/python/shroomdk/models/__init__.py index 80b74a8..e5da855 100644 --- a/python/shroomdk/models/__init__.py +++ b/python/shroomdk/models/__init__.py @@ -1,3 +1,4 @@ +from .compass import Filter, SortBy # noqa: F401 from .query import Query # noqa: F401 from .query_defaults import QueryDefaults # noqa: F401 from .query_result_set import QueryResultSet # noqa: F401 diff --git a/python/shroomdk/models/compass/__init__.py b/python/shroomdk/models/compass/__init__.py index e69de29..4e1781e 100644 --- a/python/shroomdk/models/compass/__init__.py +++ b/python/shroomdk/models/compass/__init__.py @@ -0,0 +1 @@ +from .get_query_run_results import Filter, SortBy diff --git a/python/shroomdk/models/compass/get_query_run_results.py b/python/shroomdk/models/compass/get_query_run_results.py index 25c11e8..332ea99 100644 --- a/python/shroomdk/models/compass/get_query_run_results.py +++ b/python/shroomdk/models/compass/get_query_run_results.py @@ -26,6 +26,10 @@ class Filter(BaseModel): class Config: fields = {"in_": "in"} + def dict(self, *args, **kwargs) -> dict: + kwargs.setdefault("exclude_none", True) # Exclude keys with None values + return super().dict(*args, **kwargs) + class SortBy(BaseModel): column: str @@ -39,11 +43,19 @@ class GetQueryRunResultsRpcParams(BaseModel): sortBy: Optional[Union[List[SortBy], None]] = [] page: Page + def dict(self, *args, **kwargs) -> dict: + kwargs.setdefault("exclude_none", True) # Exclude keys with None values + return super().dict(*args, **kwargs) + class GetQueryRunResultsRpcRequest(RpcRequest): method: str = "getQueryRunResults" params: List[GetQueryRunResultsRpcParams] + def dict(self, *args, **kwargs) -> dict: + kwargs.setdefault("exclude_none", True) # Exclude keys with None values + return super().dict(*args, **kwargs) + # Response class GetQueryRunResultsRpcResult(BaseModel): diff --git a/python/shroomdk/models/compass/get_sql_statement.py b/python/shroomdk/models/compass/get_sql_statement.py new file mode 100644 index 0000000..2d8e546 --- /dev/null +++ b/python/shroomdk/models/compass/get_sql_statement.py @@ -0,0 +1,26 @@ +from typing import Dict, List, Optional, Union + +from pydantic import BaseModel + +from .core.rpc_request import RpcRequest +from .core.rpc_response import RpcResponse +from .core.sql_statement import SqlStatement + + +# Request +class GetSqlStatementParams(BaseModel): + sqlStatementId: str + + +class GetSqlStatementRequest(RpcRequest): + method: str = "getSqlStatement" + params: List[GetSqlStatementParams] + + +# Response +class GetSqlStatemetnResult(BaseModel): + sqlStatement: SqlStatement + + +class GetSqlStatementResponse(RpcResponse): + result: Union[GetSqlStatemetnResult, None] diff --git a/python/shroomdk/models/query_result_set.py b/python/shroomdk/models/query_result_set.py index 7a0aa17..5d95178 100644 --- a/python/shroomdk/models/query_result_set.py +++ b/python/shroomdk/models/query_result_set.py @@ -7,6 +7,7 @@ from .query_run_stats import QueryRunStats class QueryResultSet(BaseModel): query_id: Union[str, None] = Field(None, description="The server id of the query") + status: str = Field( False, description="The status of the query (`PENDING`, `FINISHED`, `ERROR`)" ) diff --git a/python/shroomdk/rpc.py b/python/shroomdk/rpc.py index 1212cb9..4ad5fcd 100644 --- a/python/shroomdk/rpc.py +++ b/python/shroomdk/rpc.py @@ -4,6 +4,11 @@ from typing import List import requests from requests.adapters import HTTPAdapter, Retry +from shroomdk.models.compass.get_sql_statement import ( + GetSqlStatementParams, + GetSqlStatementRequest, + GetSqlStatementResponse, +) from .errors.server_error import ServerError from .errors.user_error import UserError @@ -81,6 +86,21 @@ class RPC(object): return get_query_run_resp + def get_sql_statement( + self, params: GetSqlStatementParams + ) -> GetSqlStatementResponse: + result = self._session.post( + self.url, + data=json.dumps(GetSqlStatementRequest(params=[params]).dict()), + headers=self._headers, + ) + + data = self._handle_response(result, "getSqlStatement") + + get_sql_statement_resp = GetSqlStatementResponse(**data) + + return get_sql_statement_resp + def get_query_result( self, params: GetQueryRunResultsRpcParams ) -> GetQueryRunResultsRpcResponse: @@ -91,7 +111,6 @@ class RPC(object): ) data = self._handle_response(result, "getQueryRunResults") - get_query_run_results_resp = GetQueryRunResultsRpcResponse(**data) return get_query_run_results_resp diff --git a/python/shroomdk/shroomdk.py b/python/shroomdk/shroomdk.py index 47476f0..1051ad7 100644 --- a/python/shroomdk/shroomdk.py +++ b/python/shroomdk/shroomdk.py @@ -1,9 +1,17 @@ +from typing import List, Optional, Union + from shroomdk.integrations.query_integration.compass_query_integration import ( CompassQueryIntegration, ) from shroomdk.models import Query +from shroomdk.models.compass.core.query_run import QueryRun +from shroomdk.models.compass.core.sql_statement import SqlStatement +from shroomdk.models.compass.get_sql_statement import GetSqlStatementParams +from shroomdk.models.query_result_set import QueryResultSet from shroomdk.rpc import RPC +from .models import Filter, SortBy + API_BASE_URL = "https://rpc-prod.flompass.pizza" SDK_VERSION = "2.0.0" @@ -13,6 +21,7 @@ SDK_PACKAGE = "python" class ShroomDK(object): def __init__(self, api_key: str, api_base_url: str = API_BASE_URL): self.rpc = RPC(api_base_url, api_key) + self.query_integration = CompassQueryIntegration(self.rpc) def query( self, @@ -26,7 +35,7 @@ class ShroomDK(object): page_number=1, data_source="snowflake-default", data_provider="flipside", - ): + ) -> QueryResultSet: query_integration = CompassQueryIntegration(self.rpc) return query_integration.run( @@ -45,3 +54,25 @@ class ShroomDK(object): data_provider=data_provider, ) ) + + def get_query_run(self, query_run_id: str) -> QueryRun: + return self.query_integration.get_query_run(query_run_id) + + def get_query_results( + self, + query_run_id: str, + page_number: int = 1, + page_size: int = 10000, + filters: Optional[Union[List[Filter], None]] = [], + sort_by: Optional[Union[List[SortBy], None]] = [], + ) -> QueryResultSet: + return self.query_integration.get_query_results( + query_run_id, + page_number=page_number, + page_size=page_size, + filters=filters, + sort_by=sort_by, + ) + + def get_sql_statement(self, sql_statement_id: str) -> SqlStatement: + return self.query_integration.get_sql_statement(sql_statement_id) diff --git a/python/shroomdk/tests/integrations/query_integration/test_query_compass_integration.py b/python/shroomdk/tests/integrations/query_integration/test_query_compass_integration.py index 6e7800d..5030191 100644 --- a/python/shroomdk/tests/integrations/query_integration/test_query_compass_integration.py +++ b/python/shroomdk/tests/integrations/query_integration/test_query_compass_integration.py @@ -138,7 +138,7 @@ def test_get_query_run_query(requests_mock): ) try: - result = qi._get_query_run( + result = qi._get_query_run_loop( "test_query_id", page_number=page_number, page_size=page_size, @@ -164,7 +164,7 @@ def test_get_query_run_query(requests_mock): ) try: - result = qi._get_query_run( + result = qi._get_query_run_loop( "test_query_id", page_number=page_number, page_size=page_size, @@ -188,7 +188,7 @@ def test_get_query_run_query(requests_mock): reason="OK", ) - result = qi._get_query_run( + result = qi._get_query_run_loop( "test_query_id", page_number=page_number, page_size=page_size, @@ -213,7 +213,7 @@ def test_get_query_run_query(requests_mock): ) try: - result = qi._get_query_run( + result = qi._get_query_run_loop( "test_query_id", page_number=page_number, page_size=page_size,