import polars as pl
from matplotlib import pyplot as plt
import os
import re

saveDir = r"TireDataCSV\SplitData"  # output directory for the per-section parquet files
folder = r"TireDataCSV\TransientTests"  # directory holding the input CSV files
splitFile = r"TireDataCSV\transientSplits.txt"  # section labels and boundaries

x = "ET"  # column used both as the plot x-axis and as the axis the data is split on
y = "P"  # column plotted against x

save = True  # when False, sections are only printed and marked on the plot

# Trailing number immediately before ".csv", e.g. "run12.csv" -> 12 and "12.csv" -> 12.
# Anchored at the end so names like "run12.csv.bak" are not treated as CSV inputs.
fileNumberPattern = re.compile(r"([0-9]+)\.csv$")

# Only files whose name carries a trailing number can be ordered; anything else
# (stray non-CSV files, unnumbered CSVs) is skipped instead of crashing the sort.
files = [
    f for f in os.listdir(folder)
    if os.path.isfile(os.path.join(folder, f)) and fileNumberPattern.search(f)
]
# Numeric sort so that e.g. run2 comes before run10 (a string sort would not).
files.sort(key=lambda f: int(fileNumberPattern.search(f).group(1)))

# Each non-blank line is either a section label (starts with a letter) or a
# numeric boundary value on the x column.
with open(splitFile, "r") as f:
    splits = f.readlines()
splits = [s.strip() for s in splits if s.strip() != ""]
splits = [s if re.match(r"[A-Za-z]+", s) else float(s) for s in splits]

parseDensity = 200  # NOTE(review): not referenced by the code in this file — confirm before removing

def readFile(filename):
    """Load a single CSV file into a polars DataFrame.

    Schema inference looks at the first 10000 rows, and polars is asked to
    keep reading past rows that fail to parse (``ignore_errors=True``) rather
    than raising.
    """
    return pl.read_csv(
        filename,
        ignore_errors=True,
        infer_schema_length=10000,
    )

def plot(x, y, size = 2000, title = ""):
    """Scatter-plot ``y`` against ``x`` on the current figure, thinned by a fixed stride.

    Args:
        x, y: equal-length series; their ``.name`` attributes become the axis labels.
        size: thinning target. Every k-th sample is drawn, where
            k = int(len(x) / size), at least 1. A non-positive ``size`` draws
            every sample instead of dividing by zero.
        title: text for the plot title.
    """
    # Stride-based thinning keeps very long logs quick to draw.
    skip = max(1, int(len(x) / size)) if size > 0 else 1
    plt.plot(x[::skip], y[::skip], linestyle="", marker="o")
    plt.xlabel(x.name)
    plt.ylabel(y.name)
    plt.title(title)
    # grid() with no arguments *toggles* visibility; request it explicitly so a
    # second call on the same figure does not switch the grid back off.
    plt.grid(True)

def markSection(p1, p2, label=""):
    """Mark a section on the current plot with dashed red vertical lines at ``p1`` and ``p2``.

    Args:
        p1, p2: section boundaries, in x-data coordinates.
        label: text centred between the boundaries near the bottom of the axes;
            nothing is written when it is empty.
    """
    plt.axvline(x=p1, color='red', linestyle='--')
    plt.axvline(x=p2, color='red', linestyle='--')
    if label:
        mid = (p1 + p2) / 2.0
        # Blended transform: x is in data coordinates, y is a fraction of the
        # axes height. Anchoring at data y=0 put the label outside the view
        # (where matplotlib clips data-coordinate annotations) whenever the
        # plotted data did not span zero.
        plt.annotate(
            label,
            xy=(mid, 0.02),
            xycoords=plt.gca().get_xaxis_transform(),
            ha="center",
        )

def getIndexes(t1, t2, t):
    """Convert the value window [t1, t2] into approximate positional indices of ``t``.

    The mapping is proportional: index = int(value / t[-1] * len(t)). It is
    therefore only accurate when ``t`` is evenly spaced starting near zero.
    ``t2`` is clamped to ``t[-1]`` (giving ``len(t)`` at the end, usable as an
    exclusive slice bound); ``t1`` is not clamped, so a negative ``t1`` yields
    a negative index.

    Returns:
        (i1, i2): both truncated toward zero by ``int``.
    """
    last = t[-1]
    count = len(t)
    upper = t2 if t2 < last else last

    def toIndex(value):
        return int(value / last * count)

    return toIndex(t1), toIndex(upper)

# write_parquet needs the output directory to exist; create it once up front.
if save:
    os.makedirs(saveDir, exist_ok=True)

# For each input file: plot it, then cut it into the labelled sections described
# by `splits`, print (and optionally save) each section, and mark it on the plot.
for filename in files:
    df = readFile(os.path.join(folder, filename))

    plot(df[x], df[y], title = filename)

    stem = filename.removesuffix(".csv")
    name = ""  # label waiting for the boundary that closes its section
    m1 = 0  # start of the current interval on the x column
    m2 = 0  # end of the current interval on the x column
    for split in splits:
        if isinstance(split, str):
            name = split
            continue
        # Every numeric entry closes the interval that began at the previous one.
        m1, m2 = m2, split
        if not name:
            continue
        # Keep rows whose x value lies in the half-open interval [m1, m2).
        frame = df.filter((pl.col(x) >= m1) & (pl.col(x) < m2))
        outPath = os.path.join(saveDir, "%s-%s.pq" % (name, stem))
        print(outPath)

        if save:
            frame.write_parquet(outPath)

        markSection(m1, m2, name)
        name = ""

    plt.show()