Add filtering, sorting, getsqlstatement

This commit is contained in:
Jim Myers 2023-04-21 17:24:10 -04:00
parent f18bee5f8e
commit c2e6b5e3be
12 changed files with 185 additions and 11 deletions

View File

@ -1,2 +1,3 @@
from .models import *
from .rpc import RPC # noqa: F401
from .shroomdk import ShroomDK # noqa: F401

View File

@ -1,3 +1,4 @@
from .not_found_error import NotFoundError # noqa: F401
from .query_run_errors import ( # noqa: F401
QueryRunCancelledError,
QueryRunExecutionError,

View File

@ -0,0 +1,13 @@
from typing import Union
from .base_error import BaseError
class NotFoundError(BaseError):
"""
When an object is not found on the server.
"""
def __init__(self, message: Union[str, None]):
self.message = message
super().__init__(self.message)

View File

@ -1,7 +1,8 @@
import json
from typing import Union
from typing import List, Optional, Union
from shroomdk.errors import (
NotFoundError,
QueryRunCancelledError,
QueryRunExecutionError,
QueryRunTimeoutError,
@ -18,13 +19,17 @@ from shroomdk.models import (
from shroomdk.models.compass.core.page import Page
from shroomdk.models.compass.core.query_run import QueryRun
from shroomdk.models.compass.core.result_format import ResultFormat
from shroomdk.models.compass.core.sql_statement import SqlStatement
from shroomdk.models.compass.core.tags import Tags
from shroomdk.models.compass.create_query_run import CreateQueryRunRpcParams
from shroomdk.models.compass.get_query_run import GetQueryRunRpcRequestParams
from shroomdk.models.compass.get_query_run_results import (
Filter,
GetQueryRunResultsRpcParams,
GetQueryRunResultsRpcResult,
SortBy,
)
from shroomdk.models.compass.get_sql_statement import GetSqlStatementParams
from shroomdk.rpc import RPC
from shroomdk.utils.sleep import get_elapsed_linear_seconds, linear_backoff
@ -66,7 +71,7 @@ class CompassQueryIntegration(object):
if not created_query.result or not created_query.result.queryRun:
raise SDKError("expected `query_run` from server but got `None`")
query_run = self._get_query_run(
query_run = self._get_query_run_loop(
created_query.result.queryRun.id,
page_number=query.page_number,
page_size=query.page_size,
@ -87,9 +92,70 @@ class CompassQueryIntegration(object):
query_result=query_result,
).build()
def get_sql_statement(self, sql_statement_id: str) -> SqlStatement:
response = self.rpc.get_sql_statement(
GetSqlStatementParams(**{"sqlStatementId": sql_statement_id})
)
if response.error or not response.result:
raise NotFoundError(f"SQLStatement<{sql_statement_id}> not found")
return response.result.sqlStatement
def get_query_run(self, query_run_id: str) -> QueryRun:
response = self.rpc.get_query_run(
GetQueryRunRpcRequestParams(queryRunId=query_run_id)
)
if response.error or not response.result:
raise NotFoundError(f"QueryRun<{query_run_id}> not found")
return response.result.queryRun
def get_query_results(
self,
query_run_id: str,
page_number: int = 1,
page_size: int = 100000,
filters: Optional[Union[List[Filter], None]] = [],
sort_by: Optional[Union[List[SortBy], None]] = [],
) -> QueryResultSet:
query_result = self._get_query_results(
query_run_id,
page_number=page_number if page_number else 1,
page_size=page_size if page_size else 10000,
filters=filters,
sort_by=sort_by,
)
query_run = (
query_result.redirectedToQueryRun
if query_result.redirectedToQueryRun
else query_result.originalQueryRun
)
return QueryResultSetBuilder(
query_run=query_run,
query_result=query_result,
).build()
def _get_query_results(
self, query_run_id: str, page_number: int = 1, page_size: int = 100000
self,
query_run_id: str,
page_number: int = 1,
page_size: int = 100000,
filters: Optional[Union[List[Filter], None]] = [],
sort_by: Optional[Union[List[SortBy], None]] = [],
) -> GetQueryRunResultsRpcResult:
# f2 = []
# if filters:
# for f in filters:
# d = f.dict()
# d2 = {}
# for k, v in d.items():
# if v is not None:
# d2[k] = v
# f2.append(Filter(**d2))
query_results_resp = self.rpc.get_query_result(
GetQueryRunResultsRpcParams(
queryRunId=query_run_id,
@ -98,6 +164,8 @@ class CompassQueryIntegration(object):
number=page_number,
size=page_size,
),
filters=filters,
sortBy=sort_by,
)
)
@ -121,7 +189,7 @@ class CompassQueryIntegration(object):
)
return Query(**query_default_dict)
def _get_query_run(
def _get_query_run_loop(
self,
query_run_id: str,
page_number: int = 1,
@ -184,7 +252,7 @@ class CompassQueryIntegration(object):
raise QueryRunTimeoutError(elapsed_seconds)
return self._get_query_run(
return self._get_query_run_loop(
query_run_id,
page_number,
page_size,

View File

@ -1,3 +1,4 @@
from .compass import Filter, SortBy # noqa: F401
from .query import Query # noqa: F401
from .query_defaults import QueryDefaults # noqa: F401
from .query_result_set import QueryResultSet # noqa: F401

View File

@ -0,0 +1 @@
from .get_query_run_results import Filter, SortBy

View File

@ -26,6 +26,10 @@ class Filter(BaseModel):
class Config:
fields = {"in_": "in"}
def dict(self, *args, **kwargs) -> dict:
kwargs.setdefault("exclude_none", True) # Exclude keys with None values
return super().dict(*args, **kwargs)
class SortBy(BaseModel):
column: str
@ -39,11 +43,19 @@ class GetQueryRunResultsRpcParams(BaseModel):
sortBy: Optional[Union[List[SortBy], None]] = []
page: Page
def dict(self, *args, **kwargs) -> dict:
kwargs.setdefault("exclude_none", True) # Exclude keys with None values
return super().dict(*args, **kwargs)
class GetQueryRunResultsRpcRequest(RpcRequest):
method: str = "getQueryRunResults"
params: List[GetQueryRunResultsRpcParams]
def dict(self, *args, **kwargs) -> dict:
kwargs.setdefault("exclude_none", True) # Exclude keys with None values
return super().dict(*args, **kwargs)
# Response
class GetQueryRunResultsRpcResult(BaseModel):

View File

@ -0,0 +1,26 @@
from typing import Dict, List, Optional, Union
from pydantic import BaseModel
from .core.rpc_request import RpcRequest
from .core.rpc_response import RpcResponse
from .core.sql_statement import SqlStatement
# Request
class GetSqlStatementParams(BaseModel):
sqlStatementId: str
class GetSqlStatementRequest(RpcRequest):
method: str = "getSqlStatement"
params: List[GetSqlStatementParams]
# Response
class GetSqlStatemetnResult(BaseModel):
sqlStatement: SqlStatement
class GetSqlStatementResponse(RpcResponse):
result: Union[GetSqlStatemetnResult, None]

View File

@ -7,6 +7,7 @@ from .query_run_stats import QueryRunStats
class QueryResultSet(BaseModel):
query_id: Union[str, None] = Field(None, description="The server id of the query")
status: str = Field(
False, description="The status of the query (`PENDING`, `FINISHED`, `ERROR`)"
)

View File

@ -4,6 +4,11 @@ from typing import List
import requests
from requests.adapters import HTTPAdapter, Retry
from shroomdk.models.compass.get_sql_statement import (
GetSqlStatementParams,
GetSqlStatementRequest,
GetSqlStatementResponse,
)
from .errors.server_error import ServerError
from .errors.user_error import UserError
@ -81,6 +86,21 @@ class RPC(object):
return get_query_run_resp
def get_sql_statement(
self, params: GetSqlStatementParams
) -> GetSqlStatementResponse:
result = self._session.post(
self.url,
data=json.dumps(GetSqlStatementRequest(params=[params]).dict()),
headers=self._headers,
)
data = self._handle_response(result, "getSqlStatement")
get_sql_statement_resp = GetSqlStatementResponse(**data)
return get_sql_statement_resp
def get_query_result(
self, params: GetQueryRunResultsRpcParams
) -> GetQueryRunResultsRpcResponse:
@ -91,7 +111,6 @@ class RPC(object):
)
data = self._handle_response(result, "getQueryRunResults")
get_query_run_results_resp = GetQueryRunResultsRpcResponse(**data)
return get_query_run_results_resp

View File

@ -1,9 +1,17 @@
from typing import List, Optional, Union
from shroomdk.integrations.query_integration.compass_query_integration import (
CompassQueryIntegration,
)
from shroomdk.models import Query
from shroomdk.models.compass.core.query_run import QueryRun
from shroomdk.models.compass.core.sql_statement import SqlStatement
from shroomdk.models.compass.get_sql_statement import GetSqlStatementParams
from shroomdk.models.query_result_set import QueryResultSet
from shroomdk.rpc import RPC
from .models import Filter, SortBy
API_BASE_URL = "https://rpc-prod.flompass.pizza"
SDK_VERSION = "2.0.0"
@ -13,6 +21,7 @@ SDK_PACKAGE = "python"
class ShroomDK(object):
def __init__(self, api_key: str, api_base_url: str = API_BASE_URL):
self.rpc = RPC(api_base_url, api_key)
self.query_integration = CompassQueryIntegration(self.rpc)
def query(
self,
@ -26,7 +35,7 @@ class ShroomDK(object):
page_number=1,
data_source="snowflake-default",
data_provider="flipside",
):
) -> QueryResultSet:
query_integration = CompassQueryIntegration(self.rpc)
return query_integration.run(
@ -45,3 +54,25 @@ class ShroomDK(object):
data_provider=data_provider,
)
)
def get_query_run(self, query_run_id: str) -> QueryRun:
return self.query_integration.get_query_run(query_run_id)
def get_query_results(
self,
query_run_id: str,
page_number: int = 1,
page_size: int = 10000,
filters: Optional[Union[List[Filter], None]] = [],
sort_by: Optional[Union[List[SortBy], None]] = [],
) -> QueryResultSet:
return self.query_integration.get_query_results(
query_run_id,
page_number=page_number,
page_size=page_size,
filters=filters,
sort_by=sort_by,
)
def get_sql_statement(self, sql_statement_id: str) -> SqlStatement:
return self.query_integration.get_sql_statement(sql_statement_id)

View File

@ -138,7 +138,7 @@ def test_get_query_run_query(requests_mock):
)
try:
result = qi._get_query_run(
result = qi._get_query_run_loop(
"test_query_id",
page_number=page_number,
page_size=page_size,
@ -164,7 +164,7 @@ def test_get_query_run_query(requests_mock):
)
try:
result = qi._get_query_run(
result = qi._get_query_run_loop(
"test_query_id",
page_number=page_number,
page_size=page_size,
@ -188,7 +188,7 @@ def test_get_query_run_query(requests_mock):
reason="OK",
)
result = qi._get_query_run(
result = qi._get_query_run_loop(
"test_query_id",
page_number=page_number,
page_size=page_size,
@ -213,7 +213,7 @@ def test_get_query_run_query(requests_mock):
)
try:
result = qi._get_query_run(
result = qi._get_query_run_loop(
"test_query_id",
page_number=page_number,
page_size=page_size,