Enhance query integration and model definitions with optional fields and improved defaults

This commit is contained in:
Paul Mikulskis 2025-01-23 10:08:11 -05:00
parent 2a5e4c6036
commit 751f1adc70
No known key found for this signature in database
21 changed files with 104 additions and 81 deletions

2
.gitignore vendored
View File

@ -33,3 +33,5 @@ r/shroomDK_0.1.0.tar.gz
python-sdk-example.py
r/shroomDK/api_key.txt
r/shroomDK/test_of_page2_issue.R
python/venv/
venv/

0
python/log.txt Normal file
View File

View File

@ -1,2 +1,3 @@
pytest==6.2.4
freezegun==1.1.0
freezegun==1.1.0
requests-mock==1.11.0

View File

@ -39,21 +39,22 @@ class CompassQueryIntegration(object):
def run(self, query: Query) -> QueryResultSet:
query = self._set_query_defaults(query)
# Use the default values from Query class when None
ttl_hours = int((query.ttl_minutes or 0) / 60)
max_age_minutes = query.max_age_minutes or 5 # default from Query class
retry_interval_seconds = query.retry_interval_seconds or 1 # default from Query class
create_query_run_params = CreateQueryRunRpcParams(
resultTTLHours=int(query.ttl_minutes / 60)
if query.ttl_minutes
else DEFAULTS.ttl_minutes,
sql=query.sql,
maxAgeMinutes=query.max_age_minutes
if query.max_age_minutes
else DEFAULTS.max_age_minutes,
resultTTLHours=ttl_hours,
sql=query.sql or "",
maxAgeMinutes=max_age_minutes,
tags=Tags(
sdk_language="python",
sdk_package=query.sdk_package,
sdk_version=query.sdk_version,
),
dataSource=query.data_source if query.data_source else "snowflake-default",
dataProvider=query.data_provider if query.data_provider else "flipside",
dataSource=query.data_source or "snowflake-default",
dataProvider=query.data_provider or "flipside",
)
created_query = self.rpc.create_query(create_query_run_params)
if created_query.error:
@ -67,18 +68,16 @@ class CompassQueryIntegration(object):
query_run = self._get_query_run_loop(
created_query.result.queryRun.id,
page_number=query.page_number,
page_size=query.page_size,
timeout_minutes=query.timeout_minutes if query.timeout_minutes else 20,
retry_interval_seconds=query.retry_interval_seconds
if query.retry_interval_seconds
else 1,
page_number=query.page_number or 1,
page_size=query.page_size or 100000,
timeout_minutes=query.timeout_minutes or 20,
retry_interval_seconds=retry_interval_seconds,
)
query_result = self._get_query_results(
query_run.id,
page_number=query.page_number if query.page_number else 1,
page_size=query.page_size if query.page_size else 100000,
page_number=query.page_number or 1,
page_size=query.page_size or 100000,
)
return QueryResultSetBuilder(

View File

@ -23,4 +23,4 @@ class CancelQueryRunRpcResult(BaseModel):
class CancelQueryRunRpcResponse(RpcResponse):
result: Union[CancelQueryRunRpcResult, None]
result: Union[CancelQueryRunRpcResult, None] = None

View File

@ -11,23 +11,23 @@ class QueryRun(BaseModel):
sqlStatementId: str
state: str
path: str
fileCount: Optional[int]
lastFileNumber: Optional[int]
fileNames: Optional[str]
errorName: Optional[str]
errorMessage: Optional[str]
errorData: Optional[Any]
dataSourceQueryId: Optional[str]
dataSourceSessionId: Optional[str]
startedAt: Optional[str]
queryRunningEndedAt: Optional[str]
queryStreamingEndedAt: Optional[str]
endedAt: Optional[str]
rowCount: Optional[int]
totalSize: Optional[int]
fileCount: Optional[int] = None
lastFileNumber: Optional[int] = None
fileNames: Optional[str] = None
errorName: Optional[str] = None
errorMessage: Optional[str] = None
errorData: Optional[Any] = None
dataSourceQueryId: Optional[str] = None
dataSourceSessionId: Optional[str] = None
startedAt: Optional[str] = None
queryRunningEndedAt: Optional[str] = None
queryStreamingEndedAt: Optional[str] = None
endedAt: Optional[str] = None
rowCount: Optional[int] = None
totalSize: Optional[int] = None
tags: Tags
dataSourceId: str
userId: str
createdAt: str
updatedAt: datetime
archivedAt: Optional[datetime]
archivedAt: Optional[datetime] = None

View File

@ -6,4 +6,4 @@ from pydantic import BaseModel
class RpcError(BaseModel):
code: int
message: str
data: Optional[Any]
data: Optional[Any] = None

View File

@ -8,5 +8,5 @@ from .rpc_error import RpcError
class RpcResponse(BaseModel):
jsonrpc: str
id: int
result: Union[Optional[Dict[str, Any]], None]
error: Optional[RpcError]
result: Union[Optional[Dict[str, Any]], None] = None
error: Optional[RpcError] = None

View File

@ -10,7 +10,7 @@ class SqlStatement(BaseModel):
id: str
statementHash: str
sql: str
columnMetadata: Optional[ColumnMetadata]
columnMetadata: Optional[ColumnMetadata] = None
userId: str
tags: Tags
createdAt: str

View File

@ -5,6 +5,6 @@ from pydantic import BaseModel
class Tags(BaseModel):
sdk_package: Optional[str]
sdk_version: Optional[str]
sdk_language: Optional[str]
sdk_package: Optional[str] = None
sdk_version: Optional[str] = None
sdk_language: Optional[str] = None

View File

@ -33,4 +33,4 @@ class CreateQueryRunRpcResult(BaseModel):
class CreateQueryRunRpcResponse(RpcResponse):
result: Union[CreateQueryRunRpcResult, None]
result: Union[CreateQueryRunRpcResult, None] = None

View File

@ -21,8 +21,8 @@ class GetQueryRunRpcRequest(RpcRequest):
# Response
class GetQueryRunRpcResult(BaseModel):
queryRun: QueryRun
redirectedToQueryRun: Optional[QueryRun]
redirectedToQueryRun: Optional[QueryRun] = None
class GetQueryRunRpcResponse(RpcResponse):
result: Union[GetQueryRunRpcResult, None]
result: Union[GetQueryRunRpcResult, None] = None

View File

@ -1,6 +1,6 @@
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel
from pydantic import ConfigDict, BaseModel
from .core.page import Page
from .core.page_stats import PageStats
@ -22,9 +22,13 @@ class Filter(BaseModel):
like: Optional[Any] = None
in_: Optional[List[Any]] = None
notIn: Optional[List[Any]] = None
class Config:
fields = {"in_": "in"}
# TODO[pydantic]: The following keys were removed: `fields`.
# Check https://docs.pydantic.dev/dev-v2/migration/#changes-to-config for more information.
model_config = ConfigDict(
alias_generator=None,
populate_by_name=True,
json_schema_extra={"fields": {"in_": "in"}}
)
def dict(self, *args, **kwargs) -> dict:
kwargs.setdefault("exclude_none", True) # Exclude keys with None values
@ -62,15 +66,15 @@ class GetQueryRunResultsRpcRequest(RpcRequest):
# Response
class GetQueryRunResultsRpcResult(BaseModel):
columnNames: Union[Optional[List[str]], None]
columnTypes: Union[Optional[List[str]], None]
rows: Union[List[Any], None]
page: Union[PageStats, None]
sql: Union[str, None]
format: Union[ResultFormat, None]
columnNames: Union[Optional[List[str]], None] = None
columnTypes: Union[Optional[List[str]], None] = None
rows: Union[List[Any], None] = None
page: Union[PageStats, None] = None
sql: Union[str, None] = None
format: Union[ResultFormat, None] = None
originalQueryRun: QueryRun
redirectedToQueryRun: Union[QueryRun, None]
redirectedToQueryRun: Union[QueryRun, None] = None
class GetQueryRunResultsRpcResponse(RpcResponse):
result: Union[GetQueryRunResultsRpcResult, None]
result: Union[GetQueryRunResultsRpcResult, None] = None

View File

@ -23,4 +23,4 @@ class GetSqlStatemetnResult(BaseModel):
class GetSqlStatementResponse(RpcResponse):
result: Union[GetSqlStatemetnResult, None]
result: Union[GetSqlStatemetnResult, None] = None

View File

@ -4,7 +4,7 @@ from pydantic import BaseModel, Field
class Query(BaseModel):
sql: str = Field(None, description="SQL query to execute")
sql: Optional[str] = Field(None, description="SQL query to execute")
ttl_minutes: Optional[int] = Field(
None, description="The number of minutes to cache the query results"
)
@ -21,8 +21,8 @@ class Query(BaseModel):
None,
description="An override on the cache. A value of true will Re-Execute the query.",
)
page_size: int = Field(None, description="The number of results to return per page")
page_number: int = Field(None, description="The page number to return")
page_size: Optional[int] = Field(None, description="The number of results to return per page")
page_number: Optional[int] = Field(None, description="The page number to return")
sdk_package: Optional[str] = Field(
None, description="The SDK package used for the query"
)

View File

@ -1,20 +1,21 @@
from typing import Optional
from pydantic import BaseModel, Field
class QueryDefaults(BaseModel):
ttl_minutes: int = Field(
ttl_minutes: Optional[int] = Field(
None, description="The number of minutes to cache the query results"
)
max_age_minutes: int = Field(
max_age_minutes: Optional[int] = Field(
None,
description="The max age of query results to accept before deciding to run a query again",
)
cached: bool = Field(False, description="Whether or not to cache the query results")
timeout_minutes: int = Field(
timeout_minutes: Optional[int] = Field(
None, description="The number of minutes to timeout the query"
)
retry_interval_seconds: float = Field(
retry_interval_seconds: Optional[float] = Field(
None, description="The number of seconds to wait before retrying the query"
)
page_size: int = Field(None, description="The number of results to return per page")
page_number: int = Field(None, description="The page number to return")
page_size: Optional[int] = Field(None, description="The number of results to return per page")
page_number: Optional[int] = Field(None, description="The page number to return")

View File

@ -10,7 +10,7 @@ class QueryResultSet(BaseModel):
query_id: Union[str, None] = Field(None, description="The server id of the query")
status: str = Field(
False, description="The status of the query (`PENDING`, `FINISHED`, `ERROR`)"
"PENDING", description="The status of the query (`PENDING`, `FINISHED`, `ERROR`)"
)
columns: Union[List[str], None] = Field(
None, description="The names of the columns in the result set"
@ -29,4 +29,4 @@ class QueryResultSet(BaseModel):
page: Union[PageStats, None] = Field(
None, description="Summary of page stats for this query result set"
)
error: Any
error: Any = None

View File

@ -1,40 +1,41 @@
from datetime import datetime
from typing import Optional
from pydantic import BaseModel, Field
class QueryRunStats(BaseModel):
started_at: datetime = Field(None, description="The start time of the query run.")
ended_at: datetime = Field(None, description="The end time of the query run.")
query_exec_started_at: datetime = Field(
started_at: Optional[datetime] = Field(None, description="The start time of the query run.")
ended_at: Optional[datetime] = Field(None, description="The end time of the query run.")
query_exec_started_at: Optional[datetime] = Field(
None, description="The start time of query execution."
)
query_exec_ended_at: datetime = Field(
query_exec_ended_at: Optional[datetime] = Field(
None, description="The end time of query execution."
)
streaming_started_at: datetime = Field(
streaming_started_at: Optional[datetime] = Field(
None, description="The start time of streaming query results."
)
streaming_ended_at: datetime = Field(
streaming_ended_at: Optional[datetime] = Field(
None, description="The end time of streaming query results."
)
elapsed_seconds: int = Field(
elapsed_seconds: Optional[int] = Field(
None,
description="The number of seconds elapsed between the start and end times.",
)
queued_seconds: int = Field(
queued_seconds: Optional[int] = Field(
None,
description="The number of seconds elapsed between when the query was created and when execution on the data source began.",
)
streaming_seconds: int = Field(
streaming_seconds: Optional[int] = Field(
None,
description="The number of seconds elapsed between when the query execution completed and results were fully streamed to Flipside's servers.",
)
query_exec_seconds: int = Field(
query_exec_seconds: Optional[int] = Field(
None,
description="The number of seconds elapsed between when the query execution started and when it completed on the data source.",
)
record_count: int = Field(
record_count: Optional[int] = Field(
None, description="The number of records returned by the query."
)
bytes: int = Field(None, description="The number of bytes returned by the query.")
bytes: Optional[int] = Field(None, description="The number of bytes returned by the query.")

View File

@ -6,4 +6,4 @@ from pydantic import BaseModel
class SleepConfig(BaseModel):
attempts: int
timeout_minutes: Union[int, float]
interval_seconds: Optional[float]
interval_seconds: Optional[float] = None

View File

@ -1,4 +1,6 @@
import json
import pytest
import requests_mock
from ....errors import (
ApiError,
@ -20,6 +22,12 @@ from ...utils.mock_data.get_sql_statement import get_sql_statement_response
SDK_VERSION = "1.0.2"
SDK_PACKAGE = "python"
@pytest.fixture(autouse=True)
def requests_mock_fixture():
    """Yield an active ``requests_mock.Mocker`` for the duration of a test.

    The fixture is declared with ``autouse=True``, so pytest applies it
    without a test having to list it as an argument. The mocker is entered
    before the test body runs and exited once the test finishes.
    """
    mocker = requests_mock.Mocker()
    with mocker as active_mocker:
        yield active_mocker
def get_rpc():
    """Return an ``RPC`` instance built from fixed placeholder arguments.

    NOTE(review): the two positional arguments look like a base URL and an
    API key -- confirm against the ``RPC`` constructor signature.
    """
    placeholder_args = ("https://test.com", "api_key")
    return RPC(*placeholder_args)

View File

@ -1,4 +1,6 @@
import json
import pytest
import requests_mock
from ..errors.server_error import ServerError
from ..models import Query, QueryStatus
@ -14,6 +16,11 @@ from .utils.mock_data.create_query_run import create_query_run_response
from .utils.mock_data.get_query_results import get_query_results_response
from .utils.mock_data.get_query_run import get_query_run_response
@pytest.fixture(autouse=True)
def requests_mock_fixture():
    """Yield an active ``requests_mock.Mocker`` for the duration of a test.

    The fixture is declared with ``autouse=True``, so pytest applies it
    without a test having to list it as an argument. The mocker is entered
    before the test body runs and exited once the test finishes.
    """
    mocker = requests_mock.Mocker()
    with mocker as active_mocker:
        yield active_mocker
"""
Test Defaults
"""