Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,13 @@ The `plotLaTeX` package is a recent project to make exporting Python data to a L
**[Stemplot](examples/StemPlot.ipynb)**

<p align="center">
<img src="images/example_stem.png" alt="Fig3" width="1000px">
<img src="images/example_stem.png" alt="Fig7" width="1000px">
</p>

**[Scatterplot](examples/ScatterPlot.ipynb)**

<p align="center">
<img src="images/example_scatter.png" alt="Fig8" width="1000px">
</p>


Expand Down
44 changes: 36 additions & 8 deletions examples/LinePlot.ipynb

Large diffs are not rendered by default.

209 changes: 209 additions & 0 deletions examples/ScatterPlot.ipynb

Large diffs are not rendered by default.

Binary file added images/example_scatter.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
11 changes: 10 additions & 1 deletion plotLaTeX/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,14 @@
from .box_plot import BoxPlot
from .bar_plot import Barplot, MultipleBars
from .stem_plot import StemPlot
from .scatter_plot import ScatterPlot

__all__ = ["LaTeXplot", "HistPlot", "BoxPlot", "Barplot", "MultipleBars", "StemPlot"]
__all__ = [
"LaTeXplot",
"HistPlot",
"BoxPlot",
"Barplot",
"MultipleBars",
"StemPlot",
"ScatterPlot",
]
110 changes: 110 additions & 0 deletions plotLaTeX/scatter_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
import pandas as pd
import numpy as np


class ScatterPlot:
"""
A helper class to organize and export multiple datasets for scatter plotting.
It stores x–y matrix-like structures and can export data and produce LaTeX code.
"""

def __init__(self):
self.data_stack = dict()
self.y_name_list = list()
self.x_vals = None

def data_info(self):
"""Print current stored data."""
print("Current data stack:")
self.DF = pd.DataFrame(self.data_stack)
print(self.DF.head())

def add_xvals(self, x_vals, x_axs_name="x"):
"""Store the common x-values used for scatter plotting."""
print("Set x-values.")
self.x_axs_name = x_axs_name
self.x_vals = x_vals
self.n_x_vals = len(x_vals)

self.data_stack[x_axs_name] = x_vals
self.data_info()

def add_yvals(self, y_vals, y_name):
"""
Add an additional dependent variable (scatter y-values).
y_name must be unique.
"""
self.n_y_vals = len(y_vals)

if y_name in self.y_name_list:
print("Please use a different name to add more y-data.")
else:
self.y_name_list.append(y_name)
self.data_stack[y_name] = y_vals
print(f"Added {y_name} with {len(y_vals)} entries.")

self.data_info()

def export(self, path="", f_name="scatter_results.csv"):
"""
Export stored x–y data to a csv file.
If no x-values were provided, create a default index.
"""
if self.x_vals is None or len(self.x_vals) == 0:
print("No x-values found -> using index as x-axis.")
self.data_stack["x"] = np.arange(self.n_y_vals)

self.f_name = f_name

print("**Exporting scatter data**\n")
self.data_info()

pd.DataFrame(self.data_stack).to_csv(path + f_name, index=False)

print("\n***********")
print("LaTeX code for scatter plot:")
print("***********\n")
self.latex_code()

def latex_code(self, imports=False, caption="Caption of the scatter plot."):
"""
Produce pgfplots LaTeX code for scatter plots.
"""
if imports:
print("\tDon’t forget to import the packages:\n")
print(r"\usepackage{graphicx}")
print(r"\usepackage{tikz,pgfplots}")
print("\n*\t*********\n")

print(r"\begin{figure}[h]")
print(r" \centering")
print(r" \tikzstyle{every node}=[font=\footnotesize]")
print(r" \begin{tikzpicture}")
print(r" \begin{axis}[")
print(r" ylabel={y-label},")
print(r" xlabel={x-label},")
print(r" width=7.5cm,")
print(r" height=3cm,")
print(r" grid=both,")
print(r" legend columns=" + str(len(self.y_name_list)) + ",")
print(
r" legend style={at={(0,1.05)}, anchor=south west, draw=white!15!black},"
)
print(r" ]")

# scatter plot for every y-series
for yn in self.y_name_list:
print(r" \addplot+[only marks] ")
print(
f" table[x={self.x_axs_name},y={yn},col sep=comma]"
+ r"{"
+ self.f_name
+ r"};"
)
print(r" \addlegendentry{" + yn + r"};")

print(r" \end{axis}")
print(r" \end{tikzpicture}")
print(r" \caption{" + caption + "}")
print(r" \label{fig:" + caption.replace(" ", "_") + "}")
print(r"\end{figure}")