diff --git a/.vscode/launch.json b/.vscode/launch.json
index 23e5203..bca81a9 100644
--- a/.vscode/launch.json
+++ b/.vscode/launch.json
@@ -4,6 +4,7 @@
// 欲了解更多信息,请访问: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
+
{
"name": "Python: Web",
"type": "python",
@@ -61,5 +62,14 @@
"-m", ".\\ppg2mel\\saved_models\\best_loss_step_304000.pth", "--wav_dir", ".\\wavs\\input", "--ref_wav_path", ".\\wavs\\pkq.mp3", "-o", ".\\wavs\\output\\"
]
},
+ {
+ "name": "GUI",
+ "type": "python",
+ "request": "launch",
+ "program": "mkgui\\base\\_cli.py",
+ "console": "integratedTerminal",
+ "args": ["launch-ui", "mkgui.app:mocking_bird"
+ ]
+ },
]
}
diff --git a/gui/___init__.py b/mkgui/__init__.py
similarity index 100%
rename from gui/___init__.py
rename to mkgui/__init__.py
diff --git a/gui/app.py b/mkgui/app.py
similarity index 76%
rename from gui/app.py
rename to mkgui/app.py
index d753126..8387a60 100644
--- a/gui/app.py
+++ b/mkgui/app.py
@@ -8,9 +8,11 @@ import librosa
from scipy.io.wavfile import write
import re
import numpy as np
-from opyrator.components.types import FileContent
+from mkgui.base.components.types import FileContent
from vocoder.hifigan import inference as gan_vocoder
from synthesizer.inference import Synthesizer
+from typing import Any
+import matplotlib.pyplot as plt
# Constants
AUDIO_SAMPLES_DIR = 'samples\\'
@@ -40,7 +42,7 @@ class Input(BaseModel):
..., alias="输入语音(本地wav)",
description="选择本地语音文件."
)
- upload_audio_file: FileContent = Field(..., alias="或上传语音",
+ upload_audio_file: FileContent = Field(default=None, alias="或上传语音",
description="拖拽或点击上传.", mime_type="audio/wav")
encoder: encoders = Field(
..., alias="编码模型",
@@ -58,17 +60,30 @@ class Input(BaseModel):
..., example="欢迎使用工具箱, 现已支持中文输入!", alias="输出文本内容"
)
+class AudioEntity(BaseModel):
+ content: bytes
+ mel: Any
+
class Output(BaseModel):
- result_file: FileContent = Field(
- ...,
- mime_type="audio/wav",
- description="输出音频",
- )
- source_file: FileContent = Field(
- ...,
- mime_type="audio/wav",
- description="原始音频.",
- )
+ __root__: tuple[AudioEntity, AudioEntity]
+
+ def render_output_ui(self, streamlit_app, input) -> None: # type: ignore
+ """Custom output UI.
+ If this method is implmeneted, it will be used instead of the default Output UI renderer.
+ """
+ src, result = self.__root__
+ streamlit_app.subheader("Result Audio")
+ streamlit_app.audio(result.content, format="audio/wav")
+
+ fig, ax = plt.subplots()
+ ax.imshow(src.mel, aspect="equal", interpolation="none")
+ ax.set_title("mel spectrogram(Source Audio)")
+ streamlit_app.pyplot(fig)
+ fig, ax = plt.subplots()
+ ax.imshow(result.mel, aspect="equal", interpolation="none")
+ ax.set_title("mel spectrogram(Result Audio)")
+ streamlit_app.pyplot(fig)
+
def mocking_bird(input: Input) -> Output:
"""欢迎使用MockingBird Web 2"""
@@ -78,7 +93,7 @@ def mocking_bird(input: Input) -> Output:
gan_vocoder.load_model(Path(input.vocoder.value))
# load file
- if input.upload_audio_file != NULL:
+ if input.upload_audio_file != None:
with open(TEMP_SOURCE_AUDIO, "w+b") as f:
f.write(input.upload_audio_file.as_bytes())
f.seek(0)
@@ -87,6 +102,8 @@ def mocking_bird(input: Input) -> Output:
wav, sample_rate = librosa.load(input.local_audio_file.value)
write(TEMP_SOURCE_AUDIO, sample_rate, wav) #Make sure we get the correct wav
+ source_spec = Synthesizer.make_spectrogram(wav)
+
# preprocess
encoder_wav = encoder.preprocess_wav(wav, sample_rate)
embed, _, _ = encoder.embed_utterance(encoder_wav, return_partials=True)
@@ -114,4 +131,4 @@ def mocking_bird(input: Input) -> Output:
source_file = f.read()
with open(TEMP_RESULT_AUDIO, "rb") as f:
result_file = f.read()
- return Output(source_file=source_file, result_file=result_file)
\ No newline at end of file
+ return Output(__root__=(AudioEntity(content=source_file, mel=source_spec), AudioEntity(content=result_file, mel=spec)))
\ No newline at end of file
diff --git a/mkgui/base/__init__.py b/mkgui/base/__init__.py
new file mode 100644
index 0000000..6905fa0
--- /dev/null
+++ b/mkgui/base/__init__.py
@@ -0,0 +1,2 @@
+
+from .core import Opyrator
diff --git a/mkgui/base/_cli.py b/mkgui/base/_cli.py
new file mode 100644
index 0000000..3706271
--- /dev/null
+++ b/mkgui/base/_cli.py
@@ -0,0 +1,64 @@
+"""Command line interface."""
+
+import os
+import sys
+
+import typer
+from pydantic.error_wrappers import ValidationError
+
+cli = typer.Typer()
+
+@cli.command()
+def launch_ui(opyrator: str, port: int = typer.Option(8051, "--port", "-p")) -> None:
+ """Start a graphical UI server for the opyrator.
+
+ The UI is auto-generated from the input- and output-schema of the given function.
+ """
+ # Add the current working directory to the sys path
+ # This is required to resolve the opyrator path
+ sys.path.append(os.getcwd())
+
+ from mkgui.base.ui.streamlit_ui import launch_ui
+ launch_ui(opyrator, port)
+
+
+@cli.command()
+def launch_api(
+ opyrator: str,
+ port: int = typer.Option(8080, "--port", "-p"),
+ host: str = typer.Option("0.0.0.0", "--host", "-h"),
+) -> None:
+ """Start a HTTP API server for the opyrator.
+
+ This will launch a FastAPI server based on the OpenAPI standard and with an automatic interactive documentation.
+ """
+ # Add the current working directory to the sys path
+ # This is required to resolve the opyrator path
+ sys.path.append(os.getcwd())
+
+ from mkgui.base.api.fastapi_app import launch_api # type: ignore
+
+ launch_api(opyrator, port, host)
+
+
+@cli.command()
+def call(opyrator: str, input_data: str) -> None:
+ """Execute the opyrator from command line."""
+ # Add the current working directory to the sys path
+ # This is required to resolve the opyrator path
+ sys.path.append(os.getcwd())
+
+ try:
+ from mkgui.base import Opyrator
+
+ output = Opyrator(opyrator)(input_data)
+ if output:
+ typer.echo(output.json(indent=4))
+ else:
+ typer.echo("Nothing returned!")
+ except ValidationError as ex:
+ typer.secho(str(ex), fg=typer.colors.RED, err=True)
+
+
+if __name__ == "__main__":
+ cli()
\ No newline at end of file
diff --git a/mkgui/base/api/__init__.py b/mkgui/base/api/__init__.py
new file mode 100644
index 0000000..a0c4102
--- /dev/null
+++ b/mkgui/base/api/__init__.py
@@ -0,0 +1 @@
+from .fastapi_app import create_api
diff --git a/mkgui/base/api/fastapi_app.py b/mkgui/base/api/fastapi_app.py
new file mode 100644
index 0000000..234b6c5
--- /dev/null
+++ b/mkgui/base/api/fastapi_app.py
@@ -0,0 +1,69 @@
+from typing import Any, Dict
+
+from fastapi import FastAPI, status
+from fastapi.middleware.cors import CORSMiddleware
+from starlette.responses import RedirectResponse
+
+from mkgui.base import Opyrator
+from mkgui.base.api.fastapi_utils import patch_fastapi
+
+
+def launch_api(opyrator_path: str, port: int = 8501, host: str = "0.0.0.0") -> None:
+ import uvicorn
+
+ from mkgui.base import Opyrator
+ from mkgui.base.api import create_api
+
+ app = create_api(Opyrator(opyrator_path))
+ uvicorn.run(app, host=host, port=port, log_level="info")
+
+
+def create_api(opyrator: Opyrator) -> FastAPI:
+
+ title = opyrator.name
+ if "opyrator" not in opyrator.name.lower():
+ title += " - Opyrator"
+
+ # TODO what about version?
+ app = FastAPI(title=title, description=opyrator.description)
+
+ patch_fastapi(app)
+
+ app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_credentials=True,
+ allow_methods=["*"],
+ allow_headers=["*"],
+ )
+
+ @app.post(
+ "/call",
+ operation_id="call",
+ response_model=opyrator.output_type,
+ # response_model_exclude_unset=True,
+ summary="Execute the opyrator.",
+ status_code=status.HTTP_200_OK,
+ )
+ def call(input: opyrator.input_type) -> Any: # type: ignore
+ """Executes this opyrator."""
+ return opyrator(input)
+
+ @app.get(
+ "/info",
+ operation_id="info",
+ response_model=Dict,
+ # response_model_exclude_unset=True,
+ summary="Get info metadata.",
+ status_code=status.HTTP_200_OK,
+ )
+ def info() -> Any: # type: ignore
+ """Returns informational metadata about this Opyrator."""
+ return {}
+
+ # Redirect to docs
+ @app.get("/", include_in_schema=False)
+ def root() -> Any:
+ return RedirectResponse("./docs")
+
+ return app
diff --git a/mkgui/base/api/fastapi_utils.py b/mkgui/base/api/fastapi_utils.py
new file mode 100644
index 0000000..adf582a
--- /dev/null
+++ b/mkgui/base/api/fastapi_utils.py
@@ -0,0 +1,102 @@
+"""Collection of utilities for FastAPI apps."""
+
+import inspect
+from typing import Any, Type
+
+from fastapi import FastAPI, Form
+from pydantic import BaseModel
+
+
+def as_form(cls: Type[BaseModel]) -> Any:
+ """Adds an as_form class method to decorated models.
+
+ The as_form class method can be used with FastAPI endpoints
+ """
+ new_params = [
+ inspect.Parameter(
+ field.alias,
+ inspect.Parameter.POSITIONAL_ONLY,
+ default=(Form(field.default) if not field.required else Form(...)),
+ )
+ for field in cls.__fields__.values()
+ ]
+
+ async def _as_form(**data): # type: ignore
+ return cls(**data)
+
+ sig = inspect.signature(_as_form)
+ sig = sig.replace(parameters=new_params)
+ _as_form.__signature__ = sig # type: ignore
+ setattr(cls, "as_form", _as_form)
+ return cls
+
+
+def patch_fastapi(app: FastAPI) -> None:
+ """Patch function to allow relative url resolution.
+
+ This patch is required to make fastapi fully functional with a relative url path.
+ This code snippet can be copy-pasted to any Fastapi application.
+ """
+ from fastapi.openapi.docs import get_redoc_html, get_swagger_ui_html
+ from starlette.requests import Request
+ from starlette.responses import HTMLResponse
+
+ async def redoc_ui_html(req: Request) -> HTMLResponse:
+ assert app.openapi_url is not None
+ redoc_ui = get_redoc_html(
+ openapi_url="./" + app.openapi_url.lstrip("/"),
+ title=app.title + " - Redoc UI",
+ )
+
+ return HTMLResponse(redoc_ui.body.decode("utf-8"))
+
+ async def swagger_ui_html(req: Request) -> HTMLResponse:
+ assert app.openapi_url is not None
+ swagger_ui = get_swagger_ui_html(
+ openapi_url="./" + app.openapi_url.lstrip("/"),
+ title=app.title + " - Swagger UI",
+ oauth2_redirect_url=app.swagger_ui_oauth2_redirect_url,
+ )
+
+ # insert request interceptor to have all request run on relativ path
+ request_interceptor = (
+ "requestInterceptor: (e) => {"
+ "\n\t\t\tvar url = window.location.origin + window.location.pathname"
+ '\n\t\t\turl = url.substring( 0, url.lastIndexOf( "/" ) + 1);'
+ "\n\t\t\turl = e.url.replace(/http(s)?:\/\/[^/]*\//i, url);" # noqa: W605
+ "\n\t\t\te.contextUrl = url"
+ "\n\t\t\te.url = url"
+ "\n\t\t\treturn e;}"
+ )
+
+ return HTMLResponse(
+ swagger_ui.body.decode("utf-8").replace(
+ "dom_id: '#swagger-ui',",
+ "dom_id: '#swagger-ui',\n\t\t" + request_interceptor + ",",
+ )
+ )
+
+ # remove old docs route and add our patched route
+ routes_new = []
+ for app_route in app.routes:
+ if app_route.path == "/docs": # type: ignore
+ continue
+
+ if app_route.path == "/redoc": # type: ignore
+ continue
+
+ routes_new.append(app_route)
+
+ app.router.routes = routes_new
+
+ assert app.docs_url is not None
+ app.add_route(app.docs_url, swagger_ui_html, include_in_schema=False)
+ assert app.redoc_url is not None
+ app.add_route(app.redoc_url, redoc_ui_html, include_in_schema=False)
+
+ # Make graphql realtive
+ from starlette import graphql
+
+ graphql.GRAPHIQL = graphql.GRAPHIQL.replace(
+ "({{REQUEST_PATH}}", '("." + {{REQUEST_PATH}}'
+ )
diff --git a/mkgui/base/components/__init__.py b/mkgui/base/components/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/mkgui/base/components/outputs.py b/mkgui/base/components/outputs.py
new file mode 100644
index 0000000..f4859c6
--- /dev/null
+++ b/mkgui/base/components/outputs.py
@@ -0,0 +1,43 @@
+from typing import List
+
+from pydantic import BaseModel
+
+
+class ScoredLabel(BaseModel):
+ label: str
+ score: float
+
+
+class ClassificationOutput(BaseModel):
+ __root__: List[ScoredLabel]
+
+ def __iter__(self): # type: ignore
+ return iter(self.__root__)
+
+ def __getitem__(self, item): # type: ignore
+ return self.__root__[item]
+
+ def render_output_ui(self, streamlit) -> None: # type: ignore
+ import plotly.express as px
+
+ sorted_predictions = sorted(
+ [prediction.dict() for prediction in self.__root__],
+ key=lambda k: k["score"],
+ )
+
+ num_labels = len(sorted_predictions)
+ if len(sorted_predictions) > 10:
+ num_labels = streamlit.slider(
+ "Maximum labels to show: ",
+ min_value=1,
+ max_value=len(sorted_predictions),
+ value=len(sorted_predictions),
+ )
+ fig = px.bar(
+ sorted_predictions[len(sorted_predictions) - num_labels :],
+ x="score",
+ y="label",
+ orientation="h",
+ )
+ streamlit.plotly_chart(fig, use_container_width=True)
+ # fig.show()
diff --git a/mkgui/base/components/types.py b/mkgui/base/components/types.py
new file mode 100644
index 0000000..e18e267
--- /dev/null
+++ b/mkgui/base/components/types.py
@@ -0,0 +1,29 @@
+import base64
+from typing import Any, Dict
+
+
+class FileContent(str):
+ def as_bytes(self) -> bytes:
+ return base64.b64decode(self, validate=True)
+
+ def as_str(self) -> str:
+ return self.as_bytes().decode()
+
+ @classmethod
+ def __modify_schema__(cls, field_schema: Dict[str, Any]) -> None:
+ field_schema.update(format="byte")
+
+ @classmethod
+ def __get_validators__(cls) -> Any: # type: ignore
+ yield cls.validate
+
+ @classmethod
+ def validate(cls, value: Any) -> "FileContent":
+ if isinstance(value, FileContent):
+ return value
+ elif isinstance(value, str):
+ return FileContent(value)
+ elif isinstance(value, (bytes, bytearray, memoryview)):
+ return FileContent(base64.b64encode(value).decode())
+ else:
+ raise Exception("Wrong type")
diff --git a/mkgui/base/core.py b/mkgui/base/core.py
new file mode 100644
index 0000000..069d352
--- /dev/null
+++ b/mkgui/base/core.py
@@ -0,0 +1,204 @@
+import importlib
+import inspect
+import re
+from typing import Any, Callable, Type, Union, get_type_hints
+
+from pydantic import BaseModel, parse_raw_as
+from pydantic.tools import parse_obj_as
+
+
+def name_to_title(name: str) -> str:
+ """Converts a camelCase or snake_case name to title case."""
+ # If camelCase -> convert to snake case
+ name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name)
+ name = re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower()
+ # Convert to title case
+ return name.replace("_", " ").strip().title()
+
+
+def is_compatible_type(type: Type) -> bool:
+ """Returns `True` if the type is opyrator-compatible."""
+ try:
+ if issubclass(type, BaseModel):
+ return True
+ except Exception:
+ pass
+
+ try:
+ # valid list type
+ if type.__origin__ is list and issubclass(type.__args__[0], BaseModel):
+ return True
+ except Exception:
+ pass
+
+ return False
+
+
+def get_input_type(func: Callable) -> Type:
+ """Returns the input type of a given function (callable).
+
+ Args:
+ func: The function for which to get the input type.
+
+ Raises:
+ ValueError: If the function does not have a valid input type annotation.
+ """
+ type_hints = get_type_hints(func)
+
+ if "input" not in type_hints:
+ raise ValueError(
+ "The callable MUST have a parameter with the name `input` with typing annotation. "
+ "For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
+ )
+
+ input_type = type_hints["input"]
+
+ if not is_compatible_type(input_type):
+ raise ValueError(
+ "The `input` parameter MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
+ )
+
+ # TODO: return warning if more than one input parameters
+
+ return input_type
+
+
+def get_output_type(func: Callable) -> Type:
+ """Returns the output type of a given function (callable).
+
+ Args:
+ func: The function for which to get the output type.
+
+ Raises:
+ ValueError: If the function does not have a valid output type annotation.
+ """
+ type_hints = get_type_hints(func)
+ if "return" not in type_hints:
+ raise ValueError(
+ "The return type of the callable MUST be annotated with type hints."
+ "For example: `def my_opyrator(input: InputModel) -> OutputModel:`."
+ )
+
+ output_type = type_hints["return"]
+
+ if not is_compatible_type(output_type):
+ raise ValueError(
+ "The return value MUST be a subclass of the Pydantic BaseModel or a list of Pydantic models."
+ )
+
+ return output_type
+
+
+def get_callable(import_string: str) -> Callable:
+ """Import a callable from an string."""
+ callable_seperator = ":"
+ if callable_seperator not in import_string:
+ # Use dot as seperator
+ callable_seperator = "."
+
+ if callable_seperator not in import_string:
+ raise ValueError("The callable path MUST specify the function. ")
+
+ mod_name, callable_name = import_string.rsplit(callable_seperator, 1)
+ mod = importlib.import_module(mod_name)
+ return getattr(mod, callable_name)
+
+
+class Opyrator:
+ def __init__(self, func: Union[Callable, str]) -> None:
+ if isinstance(func, str):
+ # Try to load the function from a string notion
+ self.function = get_callable(func)
+ else:
+ self.function = func
+
+ self._name = "Opyrator"
+ self._description = ""
+ self._input_type = None
+ self._output_type = None
+
+ if not callable(self.function):
+ raise ValueError("The provided function parameters is not a callable.")
+
+ if inspect.isclass(self.function):
+ raise ValueError(
+ "The provided callable is an uninitialized Class. This is not allowed."
+ )
+
+ if inspect.isfunction(self.function):
+ # The provided callable is a function
+ self._input_type = get_input_type(self.function)
+ self._output_type = get_output_type(self.function)
+
+ try:
+ # Get name
+ self._name = name_to_title(self.function.__name__)
+ except Exception:
+ pass
+
+ try:
+ # Get description from function
+ doc_string = inspect.getdoc(self.function)
+ if doc_string:
+ self._description = doc_string
+ except Exception:
+ pass
+ elif hasattr(self.function, "__call__"):
+ # The provided callable is a function
+ self._input_type = get_input_type(self.function.__call__) # type: ignore
+ self._output_type = get_output_type(self.function.__call__) # type: ignore
+
+ try:
+ # Get name
+ self._name = name_to_title(type(self.function).__name__)
+ except Exception:
+ pass
+
+ try:
+ # Get description from
+ doc_string = inspect.getdoc(self.function.__call__) # type: ignore
+ if doc_string:
+ self._description = doc_string
+
+ if (
+ not self._description
+ or self._description == "Call self as a function."
+ ):
+ # Get docstring from class instead of __call__ function
+ doc_string = inspect.getdoc(self.function)
+ if doc_string:
+ self._description = doc_string
+ except Exception:
+ pass
+ else:
+ raise ValueError("Unknown callable type.")
+
+ @property
+ def name(self) -> str:
+ return self._name
+
+ @property
+ def description(self) -> str:
+ return self._description
+
+ @property
+ def input_type(self) -> Any:
+ return self._input_type
+
+ @property
+ def output_type(self) -> Any:
+ return self._output_type
+
+ def __call__(self, input: Any, **kwargs: Any) -> Any:
+
+ input_obj = input
+
+ if isinstance(input, str):
+ # Allow json input
+ input_obj = parse_raw_as(self.input_type, input)
+
+ if isinstance(input, dict):
+ # Allow dict input
+ input_obj = parse_obj_as(self.input_type, input)
+
+ return self.function(input_obj, **kwargs)
diff --git a/mkgui/base/ui/__init__.py b/mkgui/base/ui/__init__.py
new file mode 100644
index 0000000..593b254
--- /dev/null
+++ b/mkgui/base/ui/__init__.py
@@ -0,0 +1 @@
+from .streamlit_ui import render_streamlit_ui
diff --git a/mkgui/base/ui/schema_utils.py b/mkgui/base/ui/schema_utils.py
new file mode 100644
index 0000000..4e2e0b8
--- /dev/null
+++ b/mkgui/base/ui/schema_utils.py
@@ -0,0 +1,124 @@
+from typing import Dict
+
+
+def resolve_reference(reference: str, references: Dict) -> Dict:
+ return references[reference.split("/")[-1]]
+
+
+def get_single_reference_item(property: Dict, references: Dict) -> Dict:
+ # Ref can either be directly in the properties or the first element of allOf
+ reference = property.get("$ref")
+ if reference is None:
+ reference = property["allOf"][0]["$ref"]
+ return resolve_reference(reference, references)
+
+
+def is_single_string_property(property: Dict) -> bool:
+ return property.get("type") == "string"
+
+
+def is_single_datetime_property(property: Dict) -> bool:
+ if property.get("type") != "string":
+ return False
+ return property.get("format") in ["date-time", "time", "date"]
+
+
+def is_single_boolean_property(property: Dict) -> bool:
+ return property.get("type") == "boolean"
+
+
+def is_single_number_property(property: Dict) -> bool:
+ return property.get("type") in ["integer", "number"]
+
+
+def is_single_file_property(property: Dict) -> bool:
+ if property.get("type") != "string":
+ return False
+ # TODO: binary?
+ return property.get("format") == "byte"
+
+
+def is_multi_enum_property(property: Dict, references: Dict) -> bool:
+ if property.get("type") != "array":
+ return False
+
+ if property.get("uniqueItems") is not True:
+ # Only relevant if it is a set or other datastructures with unique items
+ return False
+
+ try:
+ _ = resolve_reference(property["items"]["$ref"], references)["enum"]
+ return True
+ except Exception:
+ return False
+
+
+def is_single_enum_property(property: Dict, references: Dict) -> bool:
+ try:
+ _ = get_single_reference_item(property, references)["enum"]
+ return True
+ except Exception:
+ return False
+
+
+def is_single_dict_property(property: Dict) -> bool:
+ if property.get("type") != "object":
+ return False
+ return "additionalProperties" in property
+
+
+def is_single_reference(property: Dict) -> bool:
+ if property.get("type") is not None:
+ return False
+
+ return bool(property.get("$ref"))
+
+
+def is_multi_file_property(property: Dict) -> bool:
+ if property.get("type") != "array":
+ return False
+
+ if property.get("items") is None:
+ return False
+
+ try:
+ # TODO: binary
+ return property["items"]["format"] == "byte"
+ except Exception:
+ return False
+
+
+def is_single_object(property: Dict, references: Dict) -> bool:
+ try:
+ object_reference = get_single_reference_item(property, references)
+ if object_reference["type"] != "object":
+ return False
+ return "properties" in object_reference
+ except Exception:
+ return False
+
+
+def is_property_list(property: Dict) -> bool:
+ if property.get("type") != "array":
+ return False
+
+ if property.get("items") is None:
+ return False
+
+ try:
+ return property["items"]["type"] in ["string", "number", "integer"]
+ except Exception:
+ return False
+
+
+def is_object_list_property(property: Dict, references: Dict) -> bool:
+ if property.get("type") != "array":
+ return False
+
+ try:
+ object_reference = resolve_reference(property["items"]["$ref"], references)
+ if object_reference["type"] != "object":
+ return False
+ return "properties" in object_reference
+ except Exception:
+ return False
diff --git a/mkgui/base/ui/streamlit_ui.py b/mkgui/base/ui/streamlit_ui.py
new file mode 100644
index 0000000..9fe749f
--- /dev/null
+++ b/mkgui/base/ui/streamlit_ui.py
@@ -0,0 +1,860 @@
+import datetime
+import inspect
+import mimetypes
+import sys
+from os import getcwd, unlink
+from platform import system
+from tempfile import NamedTemporaryFile
+from typing import Any, Callable, Dict, List, Type
+from PIL import Image
+
+import pandas as pd
+import streamlit as st
+from fastapi.encoders import jsonable_encoder
+from loguru import logger
+from pydantic import BaseModel, ValidationError, parse_obj_as
+
+from mkgui.base import Opyrator
+from mkgui.base.core import name_to_title
+from mkgui.base.ui import schema_utils
+from mkgui.base.ui.streamlit_utils import CUSTOM_STREAMLIT_CSS
+
+STREAMLIT_RUNNER_SNIPPET = """
+from mkgui.base.ui import render_streamlit_ui
+from mkgui.base import Opyrator
+
+import streamlit as st
+
+# TODO: Make it configurable
+# Page config can only be setup once
+st.set_page_config(
+ page_title="MockingBird",
+ page_icon="🧊",
+ layout="wide")
+
+with st.spinner("Loading MockingBird GUI. Please wait..."):
+ opyrator = Opyrator("{opyrator_path}")
+
+render_streamlit_ui(opyrator)
+"""
+
+
+def launch_ui(opyrator_path: str, port: int = 8501) -> None:
+ with NamedTemporaryFile(
+ suffix=".py", mode="w", encoding="utf-8", delete=False
+ ) as f:
+ f.write(STREAMLIT_RUNNER_SNIPPET.format(opyrator_path=opyrator_path))
+ f.seek(0)
+
+ # TODO: PYTHONPATH="$PYTHONPATH:/workspace/opyrator/src"
+ import subprocess
+
+ python_path = f'PYTHONPATH="$PYTHONPATH:{getcwd()}"'
+ if system() == "Windows":
+ python_path = f"set PYTHONPATH=%PYTHONPATH%;{getcwd()} &&"
+
+ subprocess.run(
+ f"""{python_path} "{sys.executable}" -m streamlit run --server.port={port} --server.headless=True --runner.magicEnabled=False --server.maxUploadSize=50 --browser.gatherUsageStats=False {f.name}""",
+ shell=True,
+ )
+
+ f.close()
+ unlink(f.name)
+
+
+def function_has_named_arg(func: Callable, parameter: str) -> bool:
+ try:
+ sig = inspect.signature(func)
+ for param in sig.parameters.values():
+ if param.name == "input":
+ return True
+ except Exception:
+ return False
+ return False
+
+
+def has_output_ui_renderer(data_item: BaseModel) -> bool:
+ return hasattr(data_item, "render_output_ui")
+
+
+def has_input_ui_renderer(input_class: Type[BaseModel]) -> bool:
+ return hasattr(input_class, "render_input_ui")
+
+
+def is_compatible_audio(mime_type: str) -> bool:
+ return mime_type in ["audio/mpeg", "audio/ogg", "audio/wav"]
+
+
+def is_compatible_image(mime_type: str) -> bool:
+ return mime_type in ["image/png", "image/jpeg"]
+
+
+def is_compatible_video(mime_type: str) -> bool:
+ return mime_type in ["video/mp4"]
+
+
+class InputUI:
+ def __init__(self, session_state, input_class: Type[BaseModel]):
+ self._session_state = session_state
+ self._input_class = input_class
+
+ self._schema_properties = input_class.schema(by_alias=True).get(
+ "properties", {}
+ )
+ self._schema_references = input_class.schema(by_alias=True).get(
+ "definitions", {}
+ )
+
+ # TODO: check if state has input data
+
+ def render_ui(self, streamlit_app_root) -> None:
+ if has_input_ui_renderer(self._input_class):
+ # The input model has a rendering function
+ # The rendering also returns the current state of input data
+ self._session_state.input_data = self._input_class.render_input_ui( # type: ignore
+ st, self._session_state.input_data
+ ).dict()
+ return
+
+ required_properties = self._input_class.schema(by_alias=True).get(
+ "required", []
+ )
+ print(self._schema_properties)
+ for property_key in self._schema_properties.keys():
+ property = self._schema_properties[property_key]
+
+ if not property.get("title"):
+ # Set property key as fallback title
+ property["title"] = name_to_title(property_key)
+
+ try:
+ self._store_value(
+ property_key,
+ self._render_property(streamlit_app_root, property_key, property),
+ )
+ except Exception as e:
+ print("Exception!", e)
+ pass
+
+ def _get_default_streamlit_input_kwargs(self, key: str, property: Dict) -> Dict:
+ streamlit_kwargs = {
+ "label": property.get("title"),
+ "key": key,
+ }
+
+ if property.get("description"):
+ streamlit_kwargs["help"] = property.get("description")
+ return streamlit_kwargs
+
+ def _store_value(self, key: str, value: Any) -> None:
+ data_element = self._session_state.input_data
+ key_elements = key.split(".")
+ for i, key_element in enumerate(key_elements):
+ if i == len(key_elements) - 1:
+ # add value to this element
+ data_element[key_element] = value
+ return
+ if key_element not in data_element:
+ data_element[key_element] = {}
+ data_element = data_element[key_element]
+
+ def _get_value(self, key: str) -> Any:
+ data_element = self._session_state.input_data
+ key_elements = key.split(".")
+ for i, key_element in enumerate(key_elements):
+ if i == len(key_elements) - 1:
+ # add value to this element
+ if key_element not in data_element:
+ return None
+ return data_element[key_element]
+ if key_element not in data_element:
+ data_element[key_element] = {}
+ data_element = data_element[key_element]
+ return None
+
+ def _render_single_datetime_input(
+ self, streamlit_app: st, key: str, property: Dict
+ ) -> Any:
+ streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
+
+ if property.get("format") == "time":
+ if property.get("default"):
+ try:
+ streamlit_kwargs["value"] = datetime.time.fromisoformat( # type: ignore
+ property.get("default")
+ )
+ except Exception:
+ pass
+ return streamlit_app.time_input(**streamlit_kwargs)
+ elif property.get("format") == "date":
+ if property.get("default"):
+ try:
+ streamlit_kwargs["value"] = datetime.date.fromisoformat( # type: ignore
+ property.get("default")
+ )
+ except Exception:
+ pass
+ return streamlit_app.date_input(**streamlit_kwargs)
+ elif property.get("format") == "date-time":
+ if property.get("default"):
+ try:
+ streamlit_kwargs["value"] = datetime.datetime.fromisoformat( # type: ignore
+ property.get("default")
+ )
+ except Exception:
+ pass
+ with streamlit_app.container():
+ streamlit_app.subheader(streamlit_kwargs.get("label"))
+ if streamlit_kwargs.get("description"):
+ streamlit_app.text(streamlit_kwargs.get("description"))
+ selected_date = None
+ selected_time = None
+ date_col, time_col = streamlit_app.columns(2)
+ with date_col:
+ date_kwargs = {"label": "Date", "key": key + "-date-input"}
+ if streamlit_kwargs.get("value"):
+ try:
+ date_kwargs["value"] = streamlit_kwargs.get( # type: ignore
+ "value"
+ ).date()
+ except Exception:
+ pass
+ selected_date = streamlit_app.date_input(**date_kwargs)
+
+ with time_col:
+ time_kwargs = {"label": "Time", "key": key + "-time-input"}
+ if streamlit_kwargs.get("value"):
+ try:
+ time_kwargs["value"] = streamlit_kwargs.get( # type: ignore
+ "value"
+ ).time()
+ except Exception:
+ pass
+ selected_time = streamlit_app.time_input(**time_kwargs)
+ return datetime.datetime.combine(selected_date, selected_time)
+ else:
+ streamlit_app.warning(
+ "Date format is not supported: " + str(property.get("format"))
+ )
+
+ def _render_single_file_input(
+ self, streamlit_app: st, key: str, property: Dict
+ ) -> Any:
+ streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
+ file_extension = None
+ if "mime_type" in property:
+ file_extension = mimetypes.guess_extension(property["mime_type"])
+
+ uploaded_file = streamlit_app.file_uploader(
+ **streamlit_kwargs, accept_multiple_files=False, type=file_extension
+ )
+ if uploaded_file is None:
+ return None
+
+ bytes = uploaded_file.getvalue()
+ if property.get("mime_type"):
+ if is_compatible_audio(property["mime_type"]):
+ # Show audio
+ streamlit_app.audio(bytes, format=property.get("mime_type"))
+ if is_compatible_image(property["mime_type"]):
+ # Show image
+ streamlit_app.image(bytes)
+ if is_compatible_video(property["mime_type"]):
+ # Show video
+ streamlit_app.video(bytes, format=property.get("mime_type"))
+ return bytes
+
+ def _render_single_string_input(
+ self, streamlit_app: st, key: str, property: Dict
+ ) -> Any:
+ streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
+
+ if property.get("default"):
+ streamlit_kwargs["value"] = property.get("default")
+ elif property.get("example"):
+ # TODO: also use example for other property types
+ # Use example as value if it is provided
+ streamlit_kwargs["value"] = property.get("example")
+
+ if property.get("maxLength") is not None:
+ streamlit_kwargs["max_chars"] = property.get("maxLength")
+
+ if (
+ property.get("format")
+ or (
+ property.get("maxLength") is not None
+ and int(property.get("maxLength")) < 140 # type: ignore
+ )
+ or property.get("writeOnly")
+ ):
+ # If any format is set, use single text input
+ # If max chars is set to less than 140, use single text input
+ # If write only -> password field
+ if property.get("writeOnly"):
+ streamlit_kwargs["type"] = "password"
+ return streamlit_app.text_input(**streamlit_kwargs)
+ else:
+ # Otherwise use multiline text area
+ return streamlit_app.text_area(**streamlit_kwargs)
+
+ def _render_multi_enum_input(
+ self, streamlit_app: st, key: str, property: Dict
+ ) -> Any:
+ streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
+ reference_item = schema_utils.resolve_reference(
+ property["items"]["$ref"], self._schema_references
+ )
+ # TODO: how to select defaults
+ return streamlit_app.multiselect(
+ **streamlit_kwargs, options=reference_item["enum"]
+ )
+
+ def _render_single_enum_input(
+ self, streamlit_app: st, key: str, property: Dict
+ ) -> Any:
+
+ streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
+ reference_item = schema_utils.get_single_reference_item(
+ property, self._schema_references
+ )
+
+ if property.get("default") is not None:
+ try:
+ streamlit_kwargs["index"] = reference_item["enum"].index(
+ property.get("default")
+ )
+ except Exception:
+ # Use default selection
+ pass
+
+ return streamlit_app.selectbox(
+ **streamlit_kwargs, options=reference_item["enum"]
+ )
+
+ def _render_single_dict_input(
+ self, streamlit_app: st, key: str, property: Dict
+ ) -> Any:
+
+ # Add title and subheader
+ streamlit_app.subheader(property.get("title"))
+ if property.get("description"):
+ streamlit_app.markdown(property.get("description"))
+
+ streamlit_app.markdown("---")
+
+ current_dict = self._get_value(key)
+ if not current_dict:
+ current_dict = {}
+
+ key_col, value_col = streamlit_app.columns(2)
+
+ with key_col:
+ updated_key = streamlit_app.text_input(
+ "Key", value="", key=key + "-new-key"
+ )
+
+ with value_col:
+ # TODO: also add boolean?
+ value_kwargs = {"label": "Value", "key": key + "-new-value"}
+ if property["additionalProperties"].get("type") == "integer":
+ value_kwargs["value"] = 0 # type: ignore
+ updated_value = streamlit_app.number_input(**value_kwargs)
+ elif property["additionalProperties"].get("type") == "number":
+ value_kwargs["value"] = 0.0 # type: ignore
+ value_kwargs["format"] = "%f"
+ updated_value = streamlit_app.number_input(**value_kwargs)
+ else:
+ value_kwargs["value"] = ""
+ updated_value = streamlit_app.text_input(**value_kwargs)
+
+ streamlit_app.markdown("---")
+
+ with streamlit_app.container():
+ clear_col, add_col = streamlit_app.columns([1, 2])
+
+ with clear_col:
+ if streamlit_app.button("Clear Items", key=key + "-clear-items"):
+ current_dict = {}
+
+ with add_col:
+ if (
+ streamlit_app.button("Add Item", key=key + "-add-item")
+ and updated_key
+ ):
+ current_dict[updated_key] = updated_value
+
+ streamlit_app.write(current_dict)
+
+ return current_dict
+
+ def _render_single_reference(
+ self, streamlit_app: st, key: str, property: Dict
+ ) -> Any:
+ reference_item = schema_utils.get_single_reference_item(
+ property, self._schema_references
+ )
+ return self._render_property(streamlit_app, key, reference_item)
+
+ def _render_multi_file_input(
+ self, streamlit_app: st, key: str, property: Dict
+ ) -> Any:
+ streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
+
+ file_extension = None
+ if "mime_type" in property:
+ file_extension = mimetypes.guess_extension(property["mime_type"])
+
+ uploaded_files = streamlit_app.file_uploader(
+ **streamlit_kwargs, accept_multiple_files=True, type=file_extension
+ )
+ uploaded_files_bytes = []
+ if uploaded_files:
+ for uploaded_file in uploaded_files:
+ uploaded_files_bytes.append(uploaded_file.read())
+ return uploaded_files_bytes
+
+ def _render_single_boolean_input(
+ self, streamlit_app: st, key: str, property: Dict
+ ) -> Any:
+ streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
+
+ if property.get("default"):
+ streamlit_kwargs["value"] = property.get("default")
+ return streamlit_app.checkbox(**streamlit_kwargs)
+
+ def _render_single_number_input(
+ self, streamlit_app: st, key: str, property: Dict
+ ) -> Any:
+ streamlit_kwargs = self._get_default_streamlit_input_kwargs(key, property)
+
+ number_transform = int
+ if property.get("type") == "number":
+ number_transform = float # type: ignore
+ streamlit_kwargs["format"] = "%f"
+
+ if "multipleOf" in property:
+ # Set stepcount based on multiple of parameter
+ streamlit_kwargs["step"] = number_transform(property["multipleOf"])
+ elif number_transform == int:
+ # Set step size to 1 as default
+ streamlit_kwargs["step"] = 1
+ elif number_transform == float:
+ # Set step size to 0.01 as default
+ # TODO: adapt to default value
+ streamlit_kwargs["step"] = 0.01
+
+ if "minimum" in property:
+ streamlit_kwargs["min_value"] = number_transform(property["minimum"])
+ if "exclusiveMinimum" in property:
+ streamlit_kwargs["min_value"] = number_transform(
+ property["exclusiveMinimum"] + streamlit_kwargs["step"]
+ )
+ if "maximum" in property:
+ streamlit_kwargs["max_value"] = number_transform(property["maximum"])
+
+ if "exclusiveMaximum" in property:
+ streamlit_kwargs["max_value"] = number_transform(
+ property["exclusiveMaximum"] - streamlit_kwargs["step"]
+ )
+
+ if property.get("default") is not None:
+ streamlit_kwargs["value"] = number_transform(property.get("default")) # type: ignore
+ else:
+ if "min_value" in streamlit_kwargs:
+ streamlit_kwargs["value"] = streamlit_kwargs["min_value"]
+ elif number_transform == int:
+ streamlit_kwargs["value"] = 0
+ else:
+ # Set default value to step
+ streamlit_kwargs["value"] = number_transform(streamlit_kwargs["step"])
+
+ if "min_value" in streamlit_kwargs and "max_value" in streamlit_kwargs:
+ # TODO: Only if less than X steps
+ return streamlit_app.slider(**streamlit_kwargs)
+ else:
+ return streamlit_app.number_input(**streamlit_kwargs)
+
+ def _render_object_input(self, streamlit_app: st, key: str, property: Dict) -> Any:
+ properties = property["properties"]
+ object_inputs = {}
+ for property_key in properties:
+ property = properties[property_key]
+ if not property.get("title"):
+ # Set property key as fallback title
+ property["title"] = name_to_title(property_key)
+ # construct full key based on key parts -> required later to get the value
+ full_key = key + "." + property_key
+ object_inputs[property_key] = self._render_property(
+ streamlit_app, full_key, property
+ )
+ return object_inputs
+
+ def _render_single_object_input(
+ self, streamlit_app: st, key: str, property: Dict
+ ) -> Any:
+ # Add title and subheader
+ title = property.get("title")
+ streamlit_app.subheader(title)
+ if property.get("description"):
+ streamlit_app.markdown(property.get("description"))
+
+ object_reference = schema_utils.get_single_reference_item(
+ property, self._schema_references
+ )
+ return self._render_object_input(streamlit_app, key, object_reference)
+
+ def _render_property_list_input(
+ self, streamlit_app: st, key: str, property: Dict
+ ) -> Any:
+
+ # Add title and subheader
+ streamlit_app.subheader(property.get("title"))
+ if property.get("description"):
+ streamlit_app.markdown(property.get("description"))
+
+ streamlit_app.markdown("---")
+
+ current_list = self._get_value(key)
+ if not current_list:
+ current_list = []
+
+ value_kwargs = {"label": "Value", "key": key + "-new-value"}
+ if property["items"]["type"] == "integer":
+ value_kwargs["value"] = 0 # type: ignore
+ new_value = streamlit_app.number_input(**value_kwargs)
+ elif property["items"]["type"] == "number":
+ value_kwargs["value"] = 0.0 # type: ignore
+ value_kwargs["format"] = "%f"
+ new_value = streamlit_app.number_input(**value_kwargs)
+ else:
+ value_kwargs["value"] = ""
+ new_value = streamlit_app.text_input(**value_kwargs)
+
+ streamlit_app.markdown("---")
+
+ with streamlit_app.container():
+ clear_col, add_col = streamlit_app.columns([1, 2])
+
+ with clear_col:
+ if streamlit_app.button("Clear Items", key=key + "-clear-items"):
+ current_list = []
+
+ with add_col:
+ if (
+ streamlit_app.button("Add Item", key=key + "-add-item")
+ and new_value is not None
+ ):
+ current_list.append(new_value)
+
+ streamlit_app.write(current_list)
+
+ return current_list
+
+ def _render_object_list_input(
+ self, streamlit_app: st, key: str, property: Dict
+ ) -> Any:
+
+ # TODO: support max_items, and min_items properties
+
+ # Add title and subheader
+ streamlit_app.subheader(property.get("title"))
+ if property.get("description"):
+ streamlit_app.markdown(property.get("description"))
+
+ streamlit_app.markdown("---")
+
+ current_list = self._get_value(key)
+ if not current_list:
+ current_list = []
+
+ object_reference = schema_utils.resolve_reference(
+ property["items"]["$ref"], self._schema_references
+ )
+ input_data = self._render_object_input(streamlit_app, key, object_reference)
+
+ streamlit_app.markdown("---")
+
+ with streamlit_app.container():
+ clear_col, add_col = streamlit_app.columns([1, 2])
+
+ with clear_col:
+ if streamlit_app.button("Clear Items", key=key + "-clear-items"):
+ current_list = []
+
+ with add_col:
+ if (
+ streamlit_app.button("Add Item", key=key + "-add-item")
+ and input_data
+ ):
+ current_list.append(input_data)
+
+ streamlit_app.write(current_list)
+ return current_list
+
+ def _render_property(self, streamlit_app: st, key: str, property: Dict) -> Any:
+ if schema_utils.is_single_enum_property(property, self._schema_references):
+ return self._render_single_enum_input(streamlit_app, key, property)
+
+ if schema_utils.is_multi_enum_property(property, self._schema_references):
+ return self._render_multi_enum_input(streamlit_app, key, property)
+
+ if schema_utils.is_single_file_property(property):
+ return self._render_single_file_input(streamlit_app, key, property)
+
+ if schema_utils.is_multi_file_property(property):
+ return self._render_multi_file_input(streamlit_app, key, property)
+
+ if schema_utils.is_single_datetime_property(property):
+ return self._render_single_datetime_input(streamlit_app, key, property)
+
+ if schema_utils.is_single_boolean_property(property):
+ return self._render_single_boolean_input(streamlit_app, key, property)
+
+ if schema_utils.is_single_dict_property(property):
+ return self._render_single_dict_input(streamlit_app, key, property)
+
+ if schema_utils.is_single_number_property(property):
+ return self._render_single_number_input(streamlit_app, key, property)
+
+ if schema_utils.is_single_string_property(property):
+ return self._render_single_string_input(streamlit_app, key, property)
+
+ if schema_utils.is_single_object(property, self._schema_references):
+ return self._render_single_object_input(streamlit_app, key, property)
+
+ if schema_utils.is_object_list_property(property, self._schema_references):
+ return self._render_object_list_input(streamlit_app, key, property)
+
+ if schema_utils.is_property_list(property):
+ return self._render_property_list_input(streamlit_app, key, property)
+
+ if schema_utils.is_single_reference(property):
+ return self._render_single_reference(streamlit_app, key, property)
+
+ streamlit_app.warning(
+ "The type of the following property is currently not supported: "
+ + str(property.get("title"))
+ )
+ raise Exception("Unsupported property")
+
+
+class OutputUI:
+ def __init__(self, output_data: Any, input_data: Any):
+ self._output_data = output_data
+ self._input_data = input_data
+
+ def render_ui(self, streamlit_app) -> None:
+ try:
+ if isinstance(self._output_data, BaseModel):
+ self._render_single_output(streamlit_app, self._output_data)
+ return
+ if type(self._output_data) == list:
+ self._render_list_output(streamlit_app, self._output_data)
+ return
+ except Exception as ex:
+ streamlit_app.exception(ex)
+ # Fallback to
+ streamlit_app.json(jsonable_encoder(self._output_data))
+
+ def _render_single_text_property(
+ self, streamlit: st, property_schema: Dict, value: Any
+ ) -> None:
+ # Add title and subheader
+ streamlit.subheader(property_schema.get("title"))
+ if property_schema.get("description"):
+ streamlit.markdown(property_schema.get("description"))
+ if value is None or value == "":
+ streamlit.info("No value returned!")
+ else:
+ streamlit.code(str(value), language="plain")
+
+ def _render_single_file_property(
+ self, streamlit: st, property_schema: Dict, value: Any
+ ) -> None:
+ # Add title and subheader
+ streamlit.subheader(property_schema.get("title"))
+ if property_schema.get("description"):
+ streamlit.markdown(property_schema.get("description"))
+ if value is None or value == "":
+ streamlit.info("No value returned!")
+ else:
+ # TODO: Detect if it is a FileContent instance
+ # TODO: detect if it is base64
+ file_extension = ""
+ if "mime_type" in property_schema:
+ mime_type = property_schema["mime_type"]
+ file_extension = mimetypes.guess_extension(mime_type) or ""
+
+ if is_compatible_audio(mime_type):
+ streamlit.audio(value.as_bytes(), format=mime_type)
+ return
+
+ if is_compatible_image(mime_type):
+ streamlit.image(value.as_bytes())
+ return
+
+ if is_compatible_video(mime_type):
+ streamlit.video(value.as_bytes(), format=mime_type)
+ return
+
+ filename = (
+ (property_schema["title"] + file_extension)
+ .lower()
+ .strip()
+ .replace(" ", "-")
+ )
+ streamlit.markdown(
+ f'',
+ unsafe_allow_html=True,
+ )
+
+ def _render_single_complex_property(
+ self, streamlit: st, property_schema: Dict, value: Any
+ ) -> None:
+ # Add title and subheader
+ streamlit.subheader(property_schema.get("title"))
+ if property_schema.get("description"):
+ streamlit.markdown(property_schema.get("description"))
+
+ streamlit.json(jsonable_encoder(value))
+
+ def _render_single_output(self, streamlit: st, output_data: BaseModel) -> None:
+ try:
+ if has_output_ui_renderer(output_data):
+ if function_has_named_arg(output_data.render_output_ui, "input"): # type: ignore
+ # render method also requests the input data
+ output_data.render_output_ui(streamlit, input=self._input_data) # type: ignore
+ else:
+ output_data.render_output_ui(streamlit) # type: ignore
+ return
+ except Exception:
+ # Use default auto-generation methods if the custom rendering throws an exception
+ logger.exception(
+ "Failed to execute custom render_output_ui function. Using auto-generation instead"
+ )
+
+ model_schema = output_data.schema(by_alias=False)
+ model_properties = model_schema.get("properties")
+ definitions = model_schema.get("definitions")
+
+ if model_properties:
+ for property_key in output_data.__dict__:
+ property_schema = model_properties.get(property_key)
+ if not property_schema.get("title"):
+ # Set property key as fallback title
+ property_schema["title"] = property_key
+
+ output_property_value = output_data.__dict__[property_key]
+
+ if has_output_ui_renderer(output_property_value):
+ output_property_value.render_output_ui(streamlit) # type: ignore
+ continue
+
+ if isinstance(output_property_value, BaseModel):
+ # Render output recursivly
+ streamlit.subheader(property_schema.get("title"))
+ if property_schema.get("description"):
+ streamlit.markdown(property_schema.get("description"))
+ self._render_single_output(streamlit, output_property_value)
+ continue
+
+ if property_schema:
+ if schema_utils.is_single_file_property(property_schema):
+ self._render_single_file_property(
+ streamlit, property_schema, output_property_value
+ )
+ continue
+
+ if (
+ schema_utils.is_single_string_property(property_schema)
+ or schema_utils.is_single_number_property(property_schema)
+ or schema_utils.is_single_datetime_property(property_schema)
+ or schema_utils.is_single_boolean_property(property_schema)
+ ):
+ self._render_single_text_property(
+ streamlit, property_schema, output_property_value
+ )
+ continue
+ if definitions and schema_utils.is_single_enum_property(
+ property_schema, definitions
+ ):
+ self._render_single_text_property(
+ streamlit, property_schema, output_property_value.value
+ )
+ continue
+
+ # TODO: render dict as table
+
+ self._render_single_complex_property(
+ streamlit, property_schema, output_property_value
+ )
+ return
+
+ def _render_list_output(self, streamlit: st, output_data: List) -> None:
+ try:
+ data_items: List = []
+ for data_item in output_data:
+ if has_output_ui_renderer(data_item):
+ # Render using the render function
+ data_item.render_output_ui(streamlit) # type: ignore
+ continue
+ data_items.append(data_item.dict())
+ # Try to show as dataframe
+ streamlit.table(pd.DataFrame(data_items))
+ except Exception:
+ # Fallback to
+ streamlit.json(jsonable_encoder(output_data))
+
+
+def render_streamlit_ui(opyrator: Opyrator) -> None:
+ title = opyrator.name
+
+ # init
+ session_state = st.session_state
+ session_state.input_data = {}
+
+ col1, col2, _ = st.columns(3)
+ col2.title(title)
+ image = Image.open('.\\mkgui\\static\\mb.png')
+ col1.image(image)
+
+ # Add custom css settings
+ st.markdown(f"", unsafe_allow_html=True)
+
+ if opyrator.description:
+ st.markdown(opyrator.description)
+
+ left, right = st.columns(2)
+ InputUI(session_state=session_state, input_class=opyrator.input_type).render_ui(left)
+
+
+ with left:
+ execute_selected = st.button("Execute")
+ if execute_selected:
+ with st.spinner("Executing operation. Please wait..."):
+ try:
+ input_data_obj = parse_obj_as(
+ opyrator.input_type, session_state.input_data
+ )
+ session_state.output_data = opyrator(input=input_data_obj)
+ session_state.latest_operation_input = input_data_obj # should this really be saved as additional session object?
+ except ValidationError as ex:
+ st.error(ex)
+ else:
+ # st.success("Operation executed successfully.")
+ pass
+ if st.button("Clear"):
+ # Clear all state
+ for key in st.session_state.keys():
+ del st.session_state[key]
+ session_state.input_data = {}
+ st.experimental_rerun()
+
+
+ if 'output_data' in session_state:
+ OutputUI(
+ session_state.output_data, session_state.latest_operation_input
+ ).render_ui(right)
+
+ # st.markdown("---")
+
diff --git a/mkgui/base/ui/streamlit_utils.py b/mkgui/base/ui/streamlit_utils.py
new file mode 100644
index 0000000..e8f15c6
--- /dev/null
+++ b/mkgui/base/ui/streamlit_utils.py
@@ -0,0 +1,10 @@
+CUSTOM_STREAMLIT_CSS = """
+div[data-testid="stBlock"] button {
+ width: 100% !important;
+ margin-bottom: 20px !important;
+ border-color: #bfbfbf !important;
+}
+pre code {
+ white-space: pre-wrap;
+}
+"""
diff --git a/mkgui/static/mb.png b/mkgui/static/mb.png
new file mode 100644
index 0000000..abd804c
Binary files /dev/null and b/mkgui/static/mb.png differ
diff --git a/requirements.txt b/requirements.txt
index 21becf4..84416e2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -21,5 +21,4 @@ flask_cors==3.0.10
gevent==21.8.0
flask_restx
tensorboard
-opyrator
-streamlit==1.3.1
\ No newline at end of file
+streamlit==1.8.0
\ No newline at end of file