diff --git a/blitzortung/service/strike.py b/blitzortung/service/strike.py index aa10d16..df5d96c 100644 --- a/blitzortung/service/strike.py +++ b/blitzortung/service/strike.py @@ -27,6 +27,7 @@ from .general import TimingState, create_time_interval from .. import db, geom from ..data import Timestamp +from ..db.query import TimeInterval class StrikeState(TimingState): @@ -42,22 +43,21 @@ class StrikeQuery: def __init__(self, strike_query_builder: db.query_builder.Strike, strike_mapper: db.mapper.Strike): self.strike_query_builder = strike_query_builder self.strike_mapper = strike_mapper + self.id_order = db.query.Order('id') - def create(self, id_or_offset, minute_length, minute_offset, connection, statsd_client): - time_interval = create_time_interval(minute_length, minute_offset) + def create(self, id_or_offset, time_interval: TimeInterval, connection, statsd_client): state = StrikeState(statsd_client, Timestamp(time_interval.end)) id_interval = db.query.IdInterval(id_or_offset) if id_or_offset > 0 else None - order = db.query.Order('id') query = self.strike_query_builder.select_query(db.table.Strike.table_name, geom.Geometry.default_srid, - time_interval=time_interval, order=order, + time_interval=time_interval, order=self.id_order, id_interval=id_interval) strikes_result = connection.runQuery(str(query), query.get_parameters()) - strikes_result.addCallback(self.strike_build_results, state=state) + strikes_result.addCallback(self.build_result, state=state) return strikes_result, state - def strike_build_results(self, query_result, state): + def build_result(self, query_result, state): state.add_info_text("query %.03fs #%d" % (state.get_seconds(), len(query_result))) state.log_timing('strikes.query') @@ -89,12 +89,12 @@ def create_strikes(self, query_results): def combine_result(self, strikes_result, histogram_result, state): query = gatherResults([strikes_result, histogram_result], consumeErrors=True) - query.addCallback(self.compile_strikes_result, state=state) + query.addCallback(self.build_strikes_response, state=state) query.addErrback(log.err) return query @staticmethod - def compile_strikes_result(result, state): + def build_strikes_response(result, state): strikes_result = result[0] histogram_result = result[1] diff --git a/tests/service/test_strikes.py b/tests/service/test_strikes.py new file mode 100644 index 0000000..597abee --- /dev/null +++ b/tests/service/test_strikes.py @@ -0,0 +1,97 @@ +import datetime + +import pytest +from mock import Mock, call + +from blitzortung import geom, builder +from blitzortung.data import Timestamp +from blitzortung.service.strike import StrikeQuery, StrikeState +from tests.conftest import time_interval + + +@pytest.fixture +def query_builder(): + return Mock(name='query_builder') + + +@pytest.fixture +def connection(): + return Mock(name='connection') + + +@pytest.fixture +def statsd_client(): + return Mock(name='statsd_client') + + +@pytest.fixture +def strike_mapper(): + return Mock(name='strike_mapper') + + +class TestStrikeGridQuery: + + @pytest.fixture + def uut(self, query_builder, strike_mapper): + return StrikeQuery(query_builder, strike_mapper) + + @pytest.fixture + def state(self, statsd_client, time_interval, query_builder, connection, strike_mapper): + return StrikeState(statsd_client, Timestamp(time_interval.end)) + + def test_create(self, uut, time_interval, query_builder, connection, statsd_client, state): + result, state = uut.create(-600, time_interval, connection, statsd_client) + + query_builder.select_query.assert_called_once_with("strikes", geom.Geometry.default_srid, + time_interval=time_interval, + order=uut.id_order, id_interval=None) + + query = query_builder.select_query.return_value + assert connection.runQuery.call_args == call(str(query), query.get_parameters.return_value) + assert result == connection.runQuery.return_value + + assert result.addCallback.call_args.args == (uut.build_result,) + assert result.addCallback.call_args.kwargs["state"] == state + + def test_build_result(self, uut, state, time_interval, strike_mapper): + strike_mapper.create_object.return_value = \ + builder.Strike() \ + .set_id(1234) \ + .set_timestamp(time_interval.end - datetime.timedelta(seconds=60)) \ + .set_x(123.4) \ + .set_y(45.6) \ + .set_amplitude(1.2) \ + .set_lateral_error(3.4) \ + .set_altitude(234.1) \ + .set_station_count(12) \ + .build() + + query_result = [[567]] + + result = uut.build_result(query_result, state) + + strike_mapper.create_object.assert_called_with(query_result[0]) + + assert result == {'next': 568, 's': ((60, 123.4, 45.6, 234.1, 3.4, 1.2, 12),)} + + def test_build_empty_result(self, uut, state, time_interval, strike_mapper): + query_result = [] + + result = uut.build_result(query_result, state) + + strike_mapper.create_object.assert_not_called() + assert result == {'s': ()} + + def test_build_strikes_response(self, uut, state, time_interval): + grid_result = {'next': 123, 's': ((60, 123.4, 45.6, 234.1, 3.4, 1.2, 12),)} + histogram_result = [0, 0, 0, 0, 0, 1] + result = (grid_result, histogram_result) + + response = uut.build_strikes_response(result, state=state) + + assert response == { + 'h': [0, 0, 0, 0, 0, 1], + 'next': 123, + 's': ((60, 123.4, 45.6, 234.1, 3.4, 1.2, 12),), + 't': time_interval.end.strftime("%Y%m%dT%H:%M:%S"), + }