import csv
import pickle
from tabulate import tabulate # https://pypi.org/project/tabulate/
import asciichartpy as chart
from math import nan
import sys
import io
from typing import Optional
from spyql.log import user_debug, user_error
from spyql.nulltype import NULL
from spyql.qdict import qdict
from spyql.query_result import QueryResult
from spyql.utils import is_row_collapsable
class Writer:
    """Base class for output writers.

    On construction the destination is opened: the file at ``path`` when one
    is given, otherwise ``sys.stdout``. Subclasses must implement
    :meth:`writerow`.
    """

    @staticmethod
    def output_writers():
        """Return the mapping of upper-case writer names to writer classes."""
        return {
            "JSON": JSONWriter,
            "ORJSON": ORJSONWriter,
            "CSV": CSVWriter,
            "MEMORY": MemoryWriter,
            "SPY": SpyWriter,
            "PRETTY": PrettyWriter,
            "SQL": SQLWriter,
            "PLOT": PlotWriter,
        }

    @staticmethod
    def make_writer(to_clause: dict, output_options: Optional[dict] = None):
        """
        Factory for making an output writer based on the parsed query

        :param to_clause: falsy for the default CSV writer, otherwise a dict
            with the keys ``"name"``, ``"args"`` and ``"kwargs"``.
        :param output_options: extra keyword arguments for the writer's
            constructor. The dict is copied, never modified in place;
            ``to_clause["kwargs"]`` override entries with the same key.
        :return: the writer instance. Errors are reported through
            ``user_error``; if that call returns, the result is ``None``.
        """
        writer_name = to_clause
        try:
            # Copy so that merging the TO clause kwargs below does not leak
            # into the dict owned by the caller.
            output_options = dict(output_options) if output_options else {}
            if not to_clause:  # not TO clause, defaults to CSV
                writer_name = "CSV"
                return CSVWriter(**output_options)
            if isinstance(to_clause, dict):  # there's an output data writer
                writer_name = to_clause["name"]
                # `.get` so that an unregistered name is reported as an
                # unknown writer below instead of escaping as a KeyError.
                writer = Writer.output_writers().get(writer_name.upper())
                if writer is not None:
                    output_options.update(to_clause["kwargs"])
                    return writer(*to_clause["args"], **output_options)
            user_error(
                f"Unknown writer '{writer_name}'",
                SyntaxError("Error parsing TO statement"),
                writer_name,
            )
        except TypeError as e:
            # e.g. the writer's constructor rejected an argument
            user_error(f"Could not create '{writer_name}' writer", e)

    def __init__(self, path=None, unbuffered=False):
        """Open the output destination.

        :param path: file to write to (opened for writing, truncating it);
            ``sys.stdout`` is used when falsy.
        :param unbuffered: when true, write through an unbuffered binary
            stream wrapped in a write-through text layer.
        """
        user_debug(f"Loading writer {self.__class__.__name__}")
        self.header = []
        self.path = path
        try:
            if not unbuffered:
                self.outputfile = open(path, "w") if path else sys.stdout
            elif path:
                # Open the raw file directly. Wrapping the descriptor of a
                # buffered file object and then dropping that object lets its
                # finalizer close the descriptor the wrapper still relies on.
                self.outputfile = io.TextIOWrapper(
                    open(path, "wb", 0), write_through=True
                )
            else:
                # The writer does not own stdout (see `close`), so the
                # wrapper must not close its descriptor when it is disposed.
                self.outputfile = io.TextIOWrapper(
                    open(sys.stdout.fileno(), "wb", 0, closefd=False),
                    write_through=True,
                )
        except Exception as e:
            user_error(f"Could not open output file {path}", e)

    def close(self):
        """Close the output file; stdout (no ``path``) is left open."""
        if self.path:
            self.outputfile.close()

    def writerow(self, row):
        """Write a single row; must be implemented by subclasses."""
        raise NotImplementedError

    def writerows(self, rows):
        """Write every row of ``rows`` by calling :meth:`writerow` on each."""
        for r in rows:
            self.writerow(r)

    def result(self) -> QueryResult:
        """Gets query result, in case of writing to memory"""
        return None
class CSVWriter(Writer):
    """Writer that emits rows through the standard library ``csv`` writer."""

    def __init__(self, path=None, unbuffered=False, header=True, **options):
        """Set up the output and the underlying ``csv.writer``.

        :param header: stored as ``header_on``; not read anywhere else in
            the visible part of this class.
        :param options: forwarded unchanged to ``csv.writer``.
        """
        super().__init__(path=path, unbuffered=unbuffered)
        self.header_on = header
        csv_options = options
        self.csv = csv.writer(self.outputfile, **csv_options)

    def writerow(self, row):
        """Write one row as a CSV record."""
        self.csv.writerow(row)

    def writerows(self, rows):
        """Write all rows in one call to the csv writer."""
        self.csv.writerows(rows)
def json_default(x):
    """Default ``default=`` hook of the JSON writers in this module.

    Maps ``NULL`` to ``None`` and any other value to its ``str()`` form.
    """
    if x is NULL:
        return None
    return str(x)
class JSONWriter(Writer):
    """Writer that emits one JSON document per row, one per line."""

    def __init__(self, path=None, unbuffered=False, default=json_default, **options):
        """Set up the output and a reusable ``json.JSONEncoder``.

        :param default: fallback serializer handed to the encoder.
        :param options: forwarded unchanged to ``json.JSONEncoder``.
        """
        from json import JSONEncoder

        super().__init__(path, unbuffered)
        self.encoder = JSONEncoder(default=default, **options)
        # Encode a throwaway object right away to test the options.
        self.encoder.encode({"a": 1})

    def writerow(self, row):
        """Encode ``row`` and write it followed by a newline.

        When ``is_row_collapsable`` accepts the row, its first element is
        written on its own; otherwise the row is zipped with the header into
        a dict.
        """
        if is_row_collapsable(row, self.header):
            record = row[0]
        else:
            record = dict(zip(self.header, row))
        line = self.encoder.encode(record) + "\n"
        self.outputfile.write(line)
class ORJSONWriter(Writer):
    """JSON-lines writer backed by the optional ``orjson`` package."""

    def __init__(self, path=None, unbuffered=False, default=json_default, option=0):
        """Set up the output and load ``orjson``.

        :param default: fallback serializer passed to ``orjson.dumps``.
        :param option: ``orjson`` option flags; ``OPT_APPEND_NEWLINE`` is
            always OR-ed in.
        """
        super().__init__(path, unbuffered)
        try:
            # orjson is the only external JSON library supported for now; it
            # is not a dependency and has to be installed separately
            import orjson
        except ModuleNotFoundError as e:
            user_error(
                "`orjson` module not found. You might need to install it",
                e,
                "pip3 install orjson",
            )
        else:
            self.orjson = orjson
        self.default = default
        self.option = option | orjson.OPT_APPEND_NEWLINE

    def writerow(self, row):
        """Serialize ``row`` and write the bytes to the binary output layer."""
        # TODO optimization: only call `is_row_collapsable` in the 1st row?
        if is_row_collapsable(row, self.header):
            record = row[0]
        else:
            record = dict(zip(self.header, row))
        payload = self.orjson.dumps(record, default=self.default, option=self.option)
        # the payload is bytes, hence the write to `.buffer` of the text stream
        self.outputfile.buffer.write(payload)
class CollectWriter(Writer):
    """
    Abstract writer that keeps every record in an in-memory list and only
    emits them when flushed.
    Subclasses are required to provide the `dumprows` method.
    """

    def __init__(self, path=None, unbuffered=False):
        super().__init__(path, unbuffered)
        # the whole output is held in memory until `flush`
        self.all_rows = []

    def writerow(self, row):
        """Store the transformed row for a later dump."""
        # NOTE(review): `transformrow` is not defined in the visible part of
        # this module; presumably provided elsewhere - confirm.
        self.all_rows.append(self.transformrow(row))

    def dumprows(self, rows):
        """Emit the collected rows; must be implemented by subclasses."""
        raise NotImplementedError

    def flush(self):
        """Dump the collected rows, if there are any."""
        if not self.all_rows:
            return
        self.dumprows(self.all_rows)
class MemoryWriter(CollectWriter):
    """Collecting writer whose output is retrieved through `result`."""

    def result(self):
        """Return the collected rows and the header as a ``QueryResult``."""
        rows, header = self.all_rows, self.header
        return QueryResult(rows, header)

    def dumprows(self, rows):
        """Do nothing: rows stay in memory and are never written out."""
        return None
class PrettyWriter(CollectWriter):
    """Collecting writer that renders all rows as a table with ``tabulate``."""

    def __init__(self, path=None, unbuffered=False, header=True, **options):
        """Set up the output and remember the ``tabulate`` options.

        :param header: whether the header is passed to ``tabulate``.
        :param options: forwarded unchanged to ``tabulate``.
        """
        super().__init__(path, unbuffered)
        # Render a dummy table right away to test the options.
        tabulate([[1, 2, 3]], **options)
        self.header_on = header
        self.options = options

    def dumprows(self, rows):
        """Write the rendered table followed by a newline."""
        # TODO handle default tablefmt
        if self.header_on:
            headers = self.header
        else:
            headers = []
        table = tabulate(rows, headers, **self.options)
        self.outputfile.write(table)
        self.outputfile.write("\n")
class PlotWriter(CollectWriter):
    """Collecting writer that draws the output columns as an ASCII chart."""

    def __init__(self, path=None, unbuffered=False, header=True, height=20):
        """Set up the output and the chart settings.

        :param header: whether a legend with the header names is written.
        :param height: passed to ``asciichartpy`` as the ``height`` setting.
        """
        super().__init__(path, unbuffered)
        self.header_on = header
        self.height = height

    def dumprows(self, rows):
        """Plot one series per column and, optionally, a colored legend."""
        colors = [
            chart.cyan,
            chart.red,
            chart.magenta,
            chart.lightgray,
            chart.green,
            chart.blue,
        ]
        config = {"height": self.height, "colors": colors}
        # rows are transposed so that each column becomes one series
        series = [list(column) for column in zip(*rows)]
        self.outputfile.write(chart.plot(series, config))
        if self.header and self.header_on:
            self.outputfile.write("\n\nLegend: ")
            palette_size = len(colors)
            for position, name in enumerate(self.header):
                # colors are reused cyclically when there are more columns
                color = colors[position % palette_size]
                entry = "\t" + color + "─── " + name + chart.reset + " "
                self.outputfile.write(entry)
        self.outputfile.write("\n")
class SpyWriter(Writer):
    """Writer that emits each row as a hex-encoded pickle, one per line."""

    def __init__(self, path=None, unbuffered=False):
        super().__init__(path, unbuffered)

    @staticmethod
    def pack(row):
        """Return the pickled ``row`` as a hex string ending in a newline."""
        pickled = pickle.dumps(row)
        return pickled.hex() + "\n"

    def writerow(self, row):
        """Write the packed form of ``row``."""
        line = SpyWriter.pack(row)
        self.outputfile.write(line)
class SQLWriter(Writer):
    """Writer that batches rows into chunks and emits one statement per chunk."""

    def __init__(
        self, path=None, unbuffered=False, chunk_size=1000, table="table_name"
    ):
        """Set up the output and the chunking state.

        :param chunk_size: number of rows gathered before a statement is
            written.
        :param table: stored as ``table_name``.
        """
        super().__init__(path, unbuffered)
        self.chunk_size = chunk_size
        self.table_name = table
        self.chunk = []

    def writerow(self, row):
        """Render ``row`` as a parenthesized value tuple and queue it.

        ints and floats are written with ``str()``, ``NULL``/``None`` as
        ``NULL``, and everything else as a single-quoted string with embedded
        single quotes doubled. A statement is written once the chunk holds
        ``chunk_size`` tuples.
        """
        literals = []
        for value in row:
            if isinstance(value, (int, float)):
                literals.append(str(value))
            elif value is NULL or value is None:
                literals.append("NULL")
            else:
                escaped = str(value).replace("'", "''")
                literals.append(f"'{escaped}'")
        joined = ",".join(literals)
        self.chunk.append(f"({joined})")
        if len(self.chunk) >= self.chunk_size:
            self.writestatement()

    def writestatement(self):
        """Write the queued tuples with the statement template and reset."""
        # NOTE(review): `self.statement` is not assigned in the visible part
        # of this class; presumably set elsewhere before the first write -
        # confirm.
        values = ",".join(self.chunk)
        self.outputfile.write(self.statement.format(values))
        self.chunk = []

    def flush(self):
        """Write the tuples still queued, if any."""
        if not self.chunk:
            return
        # write leftovers...
        self.writestatement()