mirror of https://github.com/tahnok/colmi_r02_client.git
synced 2026-02-06 10:47:28 +00:00
373 lines · 11 KiB · Python
from datetime import datetime, timezone
|
|
import os
|
|
from pathlib import Path
|
|
from unittest.mock import create_autospec
|
|
|
|
from hypothesis import given, strategies as st
|
|
import pytest
|
|
from sqlalchemy import text, select, func, Dialect
|
|
from sqlalchemy.exc import IntegrityError
|
|
|
|
from colmi_r02_client.client import FullData
|
|
from colmi_r02_client import hr, steps
|
|
from colmi_r02_client.db import (
|
|
get_db_session,
|
|
create_or_find_ring,
|
|
full_sync,
|
|
Ring,
|
|
HeartRate,
|
|
SportDetail,
|
|
Sync,
|
|
get_last_sync,
|
|
DateTimeInUTC,
|
|
)
|
|
|
|
|
|
@pytest.fixture(name="address")
def get_address() -> str:
    """Provide the fake ring address shared by the sync tests."""
    fake_address = "fake"
    return fake_address
|
|
|
|
|
|
@pytest.fixture(name="empty_full_data")
def get_empty_full_data(address) -> FullData:
    """Build a FullData payload for ``address`` carrying no heart rates and no sport details."""
    return FullData(
        sport_details=[],
        heart_rates=[],
        address=address,
    )
|
|
|
|
|
|
def test_get_db_session_memory():
    """A session opened without a path can run a trivial query."""
    with get_db_session() as session:
        one = session.scalars(text("SELECT 1")).one()
        assert one == 1
|
|
|
|
|
|
def test_get_db_session_file(tmp_path: Path):
    """Opening a session on a file path creates that SQLite file."""
    db_path = tmp_path / "test.sqlite"
    assert not db_path.exists()

    with get_db_session(db_path) as session:
        result = session.scalars(text("SELECT 1")).one()
        assert result == 1

    # Checked after the session block has exited.
    assert db_path.exists()
|
|
|
|
|
|
def test_get_db_tables_exist():
    """A fresh database contains exactly the four expected tables."""
    expected_tables = {"rings", "syncs", "heart_rates", "sport_details"}
    with get_db_session() as session:
        rows = session.scalars(text("SELECT name FROM sqlite_master WHERE type ='table'"))
        table_names = set(rows.fetchall())
        assert table_names == expected_tables
|
|
|
|
|
|
def test_get_db_schema():
    """
    Compare the live table schema against the checked-in snapshot.

    The expected ``CREATE TABLE`` statements are kept in
    ``database_schema.sql`` so they are easy to read and update. When the
    schema changes intentionally, rerun with ``UPDATE_SCHEMA`` set to a
    non-empty value. The snapshot is then rewritten, and the test still fails
    on that run so the change does not go unnoticed.
    """
    # Resolve relative to this module rather than the current working
    # directory, so the test passes no matter where pytest is launched from.
    # NOTE(review): assumes this module lives in tests/ next to
    # database_schema.sql (the previous path was "tests/database_schema.sql").
    schema_path = Path(__file__).parent / "database_schema.sql"
    expected = schema_path.read_text(encoding="utf-8")
    with get_db_session() as session:
        actual = "\n\n".join(session.scalars(text("SELECT sql FROM sqlite_schema where type = 'table'")).fetchall())

    if actual == expected:
        return

    if os.getenv("UPDATE_SCHEMA"):
        schema_path.write_text(actual, encoding="utf-8")
        pytest.fail("Test failed because we rewrote the schema file")

    assert actual == expected, "Schema mismatch, if this is expected rerun with UPDATE_SCHEMA=1"
|
|
|
|
|
|
def test_create_new_ring():
    """create_or_find_ring returns a ring carrying the requested address."""
    ring_address = "address"
    with get_db_session() as session:
        created = create_or_find_ring(session, ring_address)
        assert created.address == ring_address
|
|
|
|
|
|
def test_fetch_old_ring():
    """A second lookup for the same address returns the ring from the first lookup."""
    ring_address = "address"
    with get_db_session() as session:
        first = create_or_find_ring(session, ring_address)
        second = create_or_find_ring(session, ring_address)

        assert second == first
        assert second.address == ring_address
|
|
|
|
|
|
def test_ring_sync_id_required_for_heart_rate():
    """A HeartRate row with neither ring_id nor sync_id is rejected with IntegrityError."""
    with get_db_session() as session:
        with pytest.raises(IntegrityError):
            orphan = HeartRate(
                reading=1,
                timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc),
                ring_id=None,
                sync_id=None,
            )
            session.add(orphan)
            session.commit()
|
|
|
|
|
|
def test_sync_creates_ring(address, empty_full_data):
    """Syncing data for an unseen address leaves exactly one ring with that address."""
    with get_db_session() as session:
        full_sync(session, empty_full_data)

        only_ring = session.scalars(select(Ring)).one()
        assert only_ring.address == address
|
|
|
|
|
|
def test_sync_uses_existing_ring(address, empty_full_data):
    """Syncing for an address that already has a ring must not create a second ring."""
    with get_db_session() as session:
        create_or_find_ring(session, address)
        full_sync(session, empty_full_data)

        # Wrap the aggregate in an explicit select() (SQLAlchemy 2.0 style)
        # instead of handing a bare function element to the session.
        ring_count = session.scalar(select(func.count(Ring.ring_id)))
        assert ring_count == 1
|
|
|
|
|
|
def test_sync_creates_sync(address, empty_full_data):
    """A full_sync records exactly one Sync row, linked to the synced ring."""
    with get_db_session() as session:
        full_sync(session, empty_full_data)

        recorded = session.scalars(select(Sync)).one()
        assert recorded.ring.address == address
|
|
|
|
|
|
def test_sync_writes_heart_rates():
    """A day of 288 non-zero heart-rate samples is stored as 288 rows.

    The log is stamped 11:11, but the first two stored readings must be at
    00:00 and 00:05 that day. Every row must be tagged with the ring and the
    sync that produced it.
    """
    address = "fake"
    hrl = hr.HeartRateLog(
        heart_rates=[80] * 288,
        timestamp=datetime(2024, 11, 11, 11, 11, tzinfo=timezone.utc),
        size=24,
        index=295,
        range=5,
    )
    fd = FullData(address=address, heart_rates=[hrl], sport_details=[])
    with get_db_session() as session:
        full_sync(session, fd)

        ring = session.scalars(select(Ring)).one()
        # Order explicitly: without ORDER BY, SQL gives no guarantee about row
        # order, and the index-based timestamp assertions below depend on it.
        logs = session.scalars(select(HeartRate).order_by(HeartRate.timestamp)).all()
        sync_obj = session.scalars(select(Sync)).one()

        assert len(logs) == 288
        assert all(log.ring_id == ring.ring_id for log in logs)
        assert all(log.sync_id == sync_obj.sync_id for log in logs)
        assert all(log.reading == 80 for log in logs)
        assert logs[0].timestamp == datetime(2024, 11, 11, 0, 0, tzinfo=timezone.utc)
        assert logs[1].timestamp == datetime(2024, 11, 11, 0, 5, tzinfo=timezone.utc)
|
|
|
|
|
|
def test_sync_writes_heart_rates_only_non_zero_heart_rates():
    """Of 8 non-zero and 280 zero samples, only the 8 non-zero ones end up stored."""
    hr_log = hr.HeartRateLog(
        heart_rates=[80] * 8 + [0] * 280,
        timestamp=datetime(2024, 11, 11, 11, 11, tzinfo=timezone.utc),
        size=24,
        index=295,
        range=5,
    )
    payload = FullData(address="fake", heart_rates=[hr_log], sport_details=[])

    with get_db_session() as session:
        full_sync(session, payload)
        stored = session.scalars(select(HeartRate)).all()
        assert len(stored) == 8
|
|
|
|
|
|
def test_sync_writes_heart_rates_once():
    """A second sync of the same day does not duplicate readings already stored.

    The first sync carries 8 non-zero samples and the second carries 288.
    The table must end up with 288 rows, not 296.
    """

    def make_log(readings) -> hr.HeartRateLog:
        # Both syncs describe the same day; only the readings differ.
        return hr.HeartRateLog(
            heart_rates=readings,
            timestamp=datetime(2024, 11, 11, 11, 11, tzinfo=timezone.utc),
            size=24,
            index=295,
            range=5,
        )

    address = "fake"
    partial_day = FullData(address=address, heart_rates=[make_log([80] * 8 + [0] * 280)], sport_details=[])
    full_day = FullData(address=address, heart_rates=[make_log([80] * 288)], sport_details=[])

    with get_db_session() as session:
        for payload in (partial_day, full_day):
            full_sync(session, payload)

        stored = session.scalars(select(HeartRate)).all()
        assert len(stored) == 288
|
|
|
|
|
|
def test_sync_handles_inconsistent_data(caplog):
    """A later sync that disagrees with stored readings keeps the stored values and logs the conflict."""

    def make_log(readings) -> hr.HeartRateLog:
        # Same day for both syncs; only the reported readings differ.
        return hr.HeartRateLog(
            heart_rates=readings,
            timestamp=datetime(2024, 11, 11, 11, 11, tzinfo=timezone.utc),
            size=24,
            index=295,
            range=5,
        )

    address = "fake"
    original = FullData(address=address, heart_rates=[make_log([80] * 288)], sport_details=[])
    conflicting = FullData(address=address, heart_rates=[make_log([90] * 288)], sport_details=[])

    with get_db_session() as session:
        full_sync(session, original)
        full_sync(session, conflicting)

        stored = session.scalars(select(HeartRate)).all()

        assert len(stored) == 288
        assert not any(row.reading != 80 for row in stored)
        assert "Inconsistent data detected! 2024-11-11 00:00:00+00:00 is 80 in db but got 90 from ring" in caplog.text
|
|
|
|
|
|
def test_full_sync_writes_sport_details():
    """A single SportDetail from the ring is stored once, stamped with its day, ring and sync."""
    detail = steps.SportDetail(
        year=2025,
        month=1,
        day=1,
        time_index=0,
        calories=4200,
        steps=6969,
        distance=1234,
    )
    payload = FullData(address="fake", heart_rates=[], sport_details=[[detail]])

    with get_db_session() as session:
        full_sync(session, payload)

        ring = session.scalars(select(Ring)).one()
        stored = session.scalars(select(SportDetail)).all()
        sync_row = session.scalars(select(Sync)).one()

        assert len(stored) == 1
        only = stored[0]
        assert only.ring_id == ring.ring_id
        assert only.timestamp == datetime(2025, 1, 1, 0, 0, tzinfo=timezone.utc)
        assert only.sync_id == sync_row.sync_id
|
|
|
|
|
|
def test_full_sync_no_sport_details():
    """NoData entries in sport_details produce no SportDetail rows."""
    payload = FullData(address="fake", heart_rates=[], sport_details=[steps.NoData(), steps.NoData()])
    with get_db_session() as session:
        full_sync(session, payload)

        stored = session.scalars(select(SportDetail)).all()
        assert not stored
|
|
|
|
|
|
def test_get_last_sync_never():
    """A ring with no Sync rows has no last-sync timestamp."""
    with get_db_session() as session:
        never_synced = Ring(address="foo")
        session.add(never_synced)

        last = get_last_sync(session, never_synced.address)
        assert last is None
|
|
|
|
|
|
def test_get_sync_once():
    """With a single Sync recorded, its timestamp is reported as the last sync."""
    synced_at = datetime(2024, 11, 11, 11, tzinfo=timezone.utc)
    with get_db_session() as session:
        ring = Ring(address="foo")
        session.add(Sync(ring=ring, timestamp=synced_at))
        session.commit()

        assert get_last_sync(session, ring.address) == synced_at
|
|
|
|
|
|
def test_get_sync_many():
    """With two Syncs for one ring, the later timestamp (also added last) is reported."""
    earlier = datetime(2024, 11, 11, 11, tzinfo=timezone.utc)
    later = datetime(2024, 12, 12, 12, tzinfo=timezone.utc)
    with get_db_session() as session:
        ring = Ring(address="foo")
        for synced_at in (earlier, later):
            session.add(Sync(ring=ring, timestamp=synced_at))
        session.commit()

        assert get_last_sync(session, ring.address) == later
|
|
|
|
|
|
def test_get_last_sync_two_rings():
    """get_last_sync reports each ring's own Sync, not another ring's."""
    first = datetime(2024, 11, 11, 11, tzinfo=timezone.utc)
    second = datetime(2024, 12, 12, 12, tzinfo=timezone.utc)
    with get_db_session() as session:
        ring_1 = Ring(address="foo")
        ring_2 = Ring(address="bar")
        pairs = ((ring_1, first), (ring_2, second))

        for ring, synced_at in pairs:
            session.add(Sync(ring=ring, timestamp=synced_at))
        session.commit()

        for ring, expected in pairs:
            assert get_last_sync(session, ring.address) == expected
|
|
|
|
|
|
def test_datetimes_have_timezones():
    """A timestamp read back from the database is timezone-aware and equals the stored one."""
    with get_db_session() as session:
        ring = Ring(address="foo")
        timestamp = datetime(2024, 11, 11, 11, tzinfo=timezone.utc)
        session.add(Sync(ring=ring, timestamp=timestamp))
        session.commit()

        last_sync = get_last_sync(session, ring.address)
        assert last_sync == timestamp
        # Check the value loaded from the DB, not the local `timestamp`. The
        # local is aware by construction, so asserting on it proved nothing.
        assert last_sync is not None
        assert last_sync.tzinfo is not None
|
|
|
|
|
|
def test_datetime_in_utc_process_bind_none():
    """Binding None through DateTimeInUTC yields None."""
    column_type = DateTimeInUTC()
    fake_dialect = create_autospec(Dialect)

    bound = column_type.process_bind_param(None, fake_dialect)
    assert bound is None
|
|
|
|
|
|
@pytest.mark.skip
@given(st.datetimes())
def test_datetime_in_utc_process_bind_no_tz(ts: datetime):
    """Binding a naive datetime through DateTimeInUTC is expected to raise ValueError.

    Currently skipped via the ``skip`` marker.
    """
    column_type = DateTimeInUTC()
    fake_dialect = create_autospec(Dialect)

    with pytest.raises(ValueError):
        column_type.process_bind_param(ts, fake_dialect)
|
|
|
|
|
|
@pytest.mark.skip
@given(st.datetimes(timezones=st.timezones()))
def test_datetime_in_utc_process_bind_tz(ts: datetime):
    """Binding an aware datetime is expected to yield the same instant with a UTC tzinfo.

    Currently skipped via the ``skip`` marker.
    """
    column_type = DateTimeInUTC()
    fake_dialect = create_autospec(Dialect)

    bound = column_type.process_bind_param(ts, fake_dialect)

    assert bound is not None
    assert bound.tzinfo == timezone.utc
    assert bound == ts.astimezone(timezone.utc)
|
|
|
|
|
|
def test_datetime_in_utc_process_result_none():
    """Loading a NULL (None) value through DateTimeInUTC yields None."""
    column_type = DateTimeInUTC()
    fake_dialect = create_autospec(Dialect)

    loaded = column_type.process_result_value(None, fake_dialect)
    assert loaded is None
|
|
|
|
|
|
@pytest.mark.skip
@given(st.datetimes())
def test_datetime_in_utc_process_result_no_tz(ts: datetime):
    """Loading a naive datetime is expected to yield a value with a UTC tzinfo.

    Currently skipped via the ``skip`` marker.
    """
    column_type = DateTimeInUTC()
    fake_dialect = create_autospec(Dialect)

    loaded = column_type.process_result_value(ts, fake_dialect)

    assert loaded is not None
    assert loaded.tzinfo == timezone.utc
|
|
|
|
|
|
@pytest.mark.skip
@given(st.datetimes(timezones=st.timezones()))
def test_datetime_in_utc_process_tz(ts: datetime):
    """Loading an aware datetime is expected to yield the same instant with a UTC tzinfo.

    Currently skipped via the ``skip`` marker.
    """
    column_type = DateTimeInUTC()
    fake_dialect = create_autospec(Dialect)

    loaded = column_type.process_result_value(ts, fake_dialect)

    assert loaded is not None
    assert loaded.tzinfo == timezone.utc
    assert loaded == ts.astimezone(timezone.utc)
|