マサムネの部屋

Plotly + google colabでインタラクティブなグラフデビューする

Plotlyが数年前から気になっていたんですが、遂に使いたい場面に出くわしたので、少し勉強しました。そこで、使い方を共有したいと思います。1
Djangoで動かすコードがgithubにあるので、試したい人は使ってみてください。

GitHub - msamunetogetoge/plotly_ex: start plotly in python
start plotly in python. Contribute to msamunetogetoge/plotly_ex development by creating an account on GitHub.

Djangoで動かす記事は、noteで公開してます。
この記事では、Google Colabを使ってグラフを描く方法を紹介しています。
記事の最後に記事全体で使ったコードをgistで埋め込んでいるので、使ってみてください。

スポンサーリンク

Plotlyとは?

Plotlyは、javascript やpythonで利用できるオープンソースのグラフ描画ライブラリです。テキストやコミュニティが充実していて、英語に抵抗が無ければ知りたいことは大体解決できます。
Plotlyの特徴は、何といっても描かれたグラフを感覚的に操作してほしい情報を得られるという点です。
グラフに色々な線を描いた時に、クリック一つで線を消したり出現させたりは勿論、詳しく見たい部分を拡大したり、ブラウザ上で画像を保存したり出来ます。

Plotlyのインストール

pythonで使う場合は、pipやcondaでインストール出来ます。

pip install plotly

最速でplotlyを試す方法は、pythonファイルに以下を書き込んで実行する事です。

import plotly.graph_objects as go
fig = go.Figure(data=go.Bar(y=[2, 3, 1]))
fig.write_html('first_figure.html', auto_open=True)

すると、first_figure.htmlというファイルが生成され、ブラウザでグラフが表示されます。

first_figure.html

colabだと以下のように出来る

グラフの描き方

Plotlyでは、グラフを描くためのインスタンスを生成し、そこに属性を追加する事で、グラフの種類やレイアウトを設定します。2
この記事の中では、plotlyが配布しているデータを使います。以下のようなデータです。

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 506 entries, 0 to 505
Data columns (total 11 columns):
#   Column         Non-Null Count  Dtype
---  ------         --------------  -----
0   Date           506 non-null    object
1   AAPL.Open      506 non-null    float64
2   AAPL.High      506 non-null    float64
3   AAPL.Low       506 non-null    float64
4   AAPL.Close     506 non-null    float64
5   AAPL.Volume    506 non-null    int64
6   AAPL.Adjusted  506 non-null    float64
7   dn             506 non-null    float64
8   mavg           506 non-null    float64
9   up             506 non-null    float64
10  direction      506 non-null    object
dtypes: float64(8), int64(1), object(2)
memory usage: 43.6+ KB

初めに、キャンドルスティックグラフを描いてみます。
キャンドルスティックグラフとは、決められた時間1次元のデータを収集し、その時間内の初めのデータ、最後のデータ、最大値、最小値から一本のローソク上の図形を作り、並べていくグラフです。

最初の値<最後の値の時は緑色、逆の時は赤色のローソクになります。ファイナンスの世界で良く使われるグラフのようです。

このグラフをgoogle colabで描く為のコードは次のようになります。

import plotly
import plotly.graph_objs as go
from plotly.subplots import make_subplots
import pandas as pd
# Google Colab. やJupyter Lab.でプロットするためには,以下を実行する.
import plotly.io as pio
pio.renderers.default = "colab"

df = pd.read_csv(
    'https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv')

fig = go.Figure(data=[go.Candlestick(x=df['Date'],
                                    open=df['AAPL.Open'],
                                    high=df['AAPL.High'],
                                    low=df['AAPL.Low'],
                                    close=df['AAPL.Close'])])

fig.show()
キャンドルスティックグラフ

コードを少しだけ解説します。

fig = go.Figure(data=[go.Candlestick(x=df['Date'],
                                    open=df['AAPL.Open'],
                                    high=df['AAPL.High'],
                                    low=df['AAPL.Low'],
                                    close=df['AAPL.Close'])])

go.Figure()でインスタンスを作成し、data=[]にグラフオブジェクトを格納する事で、グラフを描けます。基本的には、go.somegraph(x=,y=)の形でグラフオブジェクトが作られます。34

グラフを描く時は、len(x)=len(y)が必要なので、一つのデータフレームを使ってグラフを描いた方が安全です。go.Figure()の詳しい説明は公式ドキュメントを読んでください。
取りあえずは、グラフオブジェクトを生成し、fig.show()すればグラフが見れる、という理解で良いと思います。

複数枚のグラフを一枚に描く

グラフは色々なモノを一枚にまとめたいものです。その方法を解説します。
グラフオブジェクトに、.add_trace()とする事で、グラフを何枚でも1枚にまとめる事が出来ます。例えば、先ほどのグラフに、移動平均線2本を足すには以下のようにします。

df = pd.read_csv(
    'https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv')
close = df['AAPL.Close']

fig = go.Figure(data=[go.Candlestick(x=df['Date'],
                                    open=df['AAPL.Open'],
                                    high=df['AAPL.High'],
                                    low=df['AAPL.Low'],
                                    close=df['AAPL.Close'])])
df["sma1"] = close.rolling(window=7).mean()
df["sma2"] = close.rolling(window=14).mean()
fig.add_trace(
    go.Scatter(
        x=df['Date'],
        y=df["sma1"],
        name="sma(7)",
        line=dict(
            color='darkorange',
            width=1)))
fig.add_trace(
    go.Scatter(
        x=df['Date'],
        y=df["sma2"],
        name="sma(14)",
        line=dict(
            color='tomato',
            width=1)))

fig.show()
candlestick with sma

コードを少し解説します。

df["sma1"] = close.rolling(window=7).mean()
df["sma2"] = close.rolling(window=14).mean()

df[“sma1”]には、7期の移動平均線が格納されています。n期の移動平均線は、過去n期分の平均値を集めたデータです。
0~n-2番目の要素はNoneで、n-1番目から、

$$\begin{eqnarray}
sma(n)[n-1] = \frac{y_0 + \cdots + y_{n-1} }{n}
\end{eqnarray}$$

のように、値が入っています。

次に、グラフオブジェクトを追加するコードです。

fig.add_trace(
    go.Scatter(
        x=df['Date'],
        y=df["sma1"],
        name="sma(7)",
        line=dict(
            color='darkorange',
            width=1)))

fig.add_trace()で一枚に描くグラフを追加します。上のコードでは、scatter を追加しています。
name= で何か指定すると、グラフの横に表示される名前を指定できます。また、lime=dict()でオプションを指定すると、曲線を描いてくれます。colorは、cssで使える色の名前が使えます。
scatter plotの詳細は、公式ドキュメントを読んでください。

グラフに目印を付ける

例えば、キャンドルスティックに自分で株を売買した情報を表示したいとします。
go.scatterのオプションを弄ると実現できます。

df = pd.read_csv(
    'https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv')

close = df['AAPL.Close']

fig = go.Figure(data=[go.Candlestick(x=df['Date'],
                                    open=df['AAPL.Open'],
                                    high=df['AAPL.High'],
                                    low=df['AAPL.Low'],
                                    close=df['AAPL.Close'])])
df["sma1"] = close.rolling(window=7).mean()
df["sma2"] = close.rolling(window=14).mean()
fig.add_trace(
    go.Scatter(
        x=df['Date'],
        y=df["sma1"],
        name="sma(7)",
        line=dict(
            color='darkorange',
            width=1)))
fig.add_trace(
    go.Scatter(
        x=df['Date'],
        y=df["sma2"],
        name="sma(14)",
        line=dict(
            color='tomato',
            width=1)))

event = [{"side": "BUY",
            "price": df['AAPL.Close'][0],
            "size":1000},
        {"side": "SELL",
            "price": df['AAPL.Close'][100],
            "size":1000},
        {"side": "BUY",
            "price": df['AAPL.Close'][300],
            "size":1000}]
fig.add_trace(go.Scatter(x=df["Date"][[0, 100, 300]], y=df["AAPL.Close"][[0, 100, 300]], name="orders", mode="markers",
                        text=event, textposition="bottom left", textfont=dict(
    family="sans serif",
    size=20,
    color="black"),
    marker=dict(
    color='maroon',
    size=6,)
   ))

fig.show()
candlestick with markers


グラフ上の赤い点にマウスカーソルを当てると、text=で指定した情報が表示されます。

コードを少し解説します。
初めに、リスト形式で、表示したい情報を書いています。

event = [{"side": "BUY",
             "price": df['AAPL.Close'][0],
             "size":1000},
            {"side": "SELL",
             "price": df['AAPL.Close'][100],
             "size":1000},
            {"side": "BUY",
             "price": df['AAPL.Close'][300],
             "size":1000}]

次に,fig.add_trace()でマーカー用のグラフを追加しています。

   fig.add_trace(go.Scatter(x=df["Date"][[0, 100, 300]], 
       y=df["AAPL.Close"][[0, 100, 300]], 
       name="orders", mode="markers",
       text=event, textposition="bottom left", textfont=dict(
       family="sans serif",
       size=20,
       color="black"),
       marker=dict(
       color='maroon',
       size=6,)
   ))

go.scatter()で、scatter plotを作るのですが、点を打つ場所の指定には、前に使ったデータフレームを使っています。
mode=”markers”と指定する事で、グラフ上に丸印を打つことが出来ます。
マーカーの書式については、textfont で指定します。
また、mode=””には、lineやline+text, markers+textがあり、+textをすると、text=で指定したデータが、直接グラフ上に表示されます。
何処に文字を表示するかをtextposition=””で指定します。5
詳しくは公式のドキュメントを読んでください。

複数のy軸を使用する

1枚のグラフで、y軸を二つ用意したい事があります。例えば、キャンドルスティックでは、出来高を同時に表示させたいと思ったりします。
そんな時は、グラフオブジェクトに対して、secondary_y =True を指定します。
例えば、次のように書きます。



df = pd.read_csv(
       'https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv')

close = df['AAPL.Close']

fig = make_subplots(specs=[[{"secondary_y": True}]])

fig.add_trace(go.Bar(x=df['Date'],
                    y=df['AAPL.Volume'],
                    name="Volume"), secondary_y=False)

fig.add_trace(go.Candlestick(x=df['Date'],
                            open=df['AAPL.Open'],
                            high=df['AAPL.High'],
                            low=df['AAPL.Low'],
                            close=df['AAPL.Close']),
                secondary_y=True)
df["sma1"] = close.rolling(window=7).mean()
df["sma2"] = close.rolling(window=14).mean()
fig.add_trace(
    go.Scatter(
        x=df['Date'],
        y=df["sma1"],
        name="sma(7)",
        line=dict(
            color='darkorange',
            width=1)))
fig.add_trace(
    go.Scatter(
        x=df['Date'],
        y=df["sma2"],
        name="sma(14)",
        line=dict(
            color='tomato',
            width=1)),
    secondary_y=True)

event = [{"side": "BUY",
            "price": df['AAPL.Close'][0],
            "size":1000},
        {"side": "SELL",
            "price": df['AAPL.Close'][100],
            "size":1000},
        {"side": "BUY",
            "price": df['AAPL.Close'][300],
            "size":1000}]
fig.add_trace(
    go.Scatter(x=df["Date"][[0, 100, 300]], y=df["AAPL.Close"][[0, 100, 300]],
                name="orders", mode="markers",
                text=event, textposition="bottom left", textfont=dict(
        family="sans serif",
        size=20,
        color="black"),
        marker=dict(
        color='maroon',
        size=6,)
    ),
    secondary_y=True)

fig.show()
candle stick with volumes

左側にもy軸が表れたのが分かると思います。コードを少し詳しく見ます。

fig = make_subplots(specs=[[{"secondary_y": True}]])

 fig.add_trace(go.Bar(x=df['Date'],
                        y=df['AAPL.Volume'],
                        name="Volume"), secondary_y=False)

 fig.add_trace(go.Candlestick(x=df['Date'],
                                open=df['AAPL.Open'],
                                high=df['AAPL.High'],
                                low=df['AAPL.Low'],
                                close=df['AAPL.Close']),
                 secondary_y=True)

先ほどまでのグラフは、fig=go.Figure()と始まっていましたが、今回はfig= make_subplots()で始まります。
描かれるグラフに何か細工したい時は、make_subplots()で始める事が多いです。

fig = make_subplots(specs=[[{"secondary_y": True}]])

と指定する事で、2つのy軸を使用するグラフを描くと宣言しています。
この後に、

fig.add_trace(go.Bar(), secondary_y=False)
fig.add_trace(go.CandleStick(), secondary_y=Trace)

として、2つのy軸を使用するグラフをそれぞれ描いています。
add_trace()で、secondary_y=Trueとすると、右側のy軸を基準にしたグラフだと認識されます。
secondary_y =False だと左側のy軸基準になります。
もっと詳しい話は、公式ドキュメントを読んでください。

グラフを縦(横)に重ねる

グラフを縦や横に重ねて、セットにしたい事があると思います。plotlyでそうすると、スライダーを共通に出来たりして、欲しい情報に素早くアクセスできるようになります。

これは、make_subplotsに書き加える事で実現できます。matplotlibと同じ感じなのでとっつきやすいかもしれません。
縦にグラフを2枚置きたい時は、例えば次のようにします。

df = pd.read_csv(
    'https://raw.githubusercontent.com/plotly/datasets/master/finance-charts-apple.csv')

close = df['AAPL.Close']

fig = make_subplots(rows=2,cols=1, shared_xaxes=True, row_heights=[0.6, 0.4],
                    specs=[[{"secondary_y": True}],
                            [{}]
                            ]
                    )

fig.add_trace(go.Scatter(x=df['Date'], y=df["AAPL.Adjusted"], name="Adjusted"),
                row=2, col=1,
                )

fig.add_trace(go.Bar(x=df['Date'],
                    y=df['AAPL.Volume'],
                    name="Volume"), row=1, col=1, secondary_y=False)

fig.add_trace(go.Candlestick(x=df['Date'],
                            open=df['AAPL.Open'],
                            high=df['AAPL.High'],
                            low=df['AAPL.Low'],
                            close=df['AAPL.Close']
                            ),
                row=1, col=1,
                secondary_y=True)
df["sma1"] = close.rolling(window=7).mean()
df["sma2"] = close.rolling(window=14).mean()
fig.add_trace(
    go.Scatter(
        x=df['Date'],
        y=df["sma1"],
        name="sma(7)",
        line=dict(
            color='darkorange',
            width=1)
    ),
    row=1, col=1, secondary_y=True)
fig.add_trace(
    go.Scatter(
        x=df['Date'],
        y=df["sma2"],
        name="sma(14)",
        line=dict(
            color='tomato',
            width=1)
    ),
    row=1, col=1, secondary_y=True)

event = [{"side": "BUY",
            "price": df['AAPL.Close'][0],
            "size":1000},
        {"side": "SELL",
            "price": df['AAPL.Close'][100],
            "size":1000},
        {"side": "BUY",
            "price": df['AAPL.Close'][300],
            "size":1000}]
fig.add_trace(
    go.Scatter(x=df["Date"][[0, 100, 300]], y=df["AAPL.Close"][[0, 100, 300]],
                name="orders", mode="markers",
                text=event, textposition="bottom left", textfont=dict(
        family="sans serif",
        size=20,
        color="black"),
        marker=dict(
        color='maroon',
        size=6,)
    ),
    row=1, col=1, secondary_y=True)

fig.show()
two rows graph

グラフが縦に2枚表示されます。
真ん中のスライドバーを弄ると、上下が連動して動きます。

zoom two rows graph

コードを見ましょう。変更したのは、make_subplots()や、add_trace()のオプションに行や列を識別する為のcol=, row=を入れた部分です。

  fig = make_subplots(rows=2, cols=1, shared_xaxes=True, row_heights=[0.6, 0.4],
                       specs=[[{"secondary_y": True}],
                              [{}]
                              ]
                       )

make_subplotsの中に設定を書き込んでいます。
rows=2,cols=1 は縦に二枚、横に一枚グラフを描く事を示しています。
shared_xaxes=Trueは、全てのグラフでx軸を共有するという事です。
 row_heights=[0.6, 0.4]で、それぞれのグラフの高さを比率で決めています。もしもrows=3としていたら、row_heights=[0.6, 0.2, 0.2]などとします。
specs=[[],[] ]の部分は、どのグラフに設定を適用するか、を決めています。
rows=2,cols=2とかにすると、2×2行列のように設定を書かなくてはいけません。
次に、グラフを作成する部分を見ます。

  fig.add_trace(go.Scatter(x=df['Date'], y=df["AAPL.Adjusted"], name="Adjusted"),
                 row=2, col=1,
                 )

   fig.add_trace(go.Bar(x=df['Date'],
                        y=df['AAPL.Volume'],
                        name="Volume"), row=1, col=1, secondary_y=False)

このコードは、大雑把には次のようになっています。

 fig.add_trace(go.Scatter(),row=2, col=1)
 fig.add_trace(go.Bar(), row=1, col=1, secondary_y=False)

make_subpliotsで、rowsやcolsで2以上を指定した時には、add_trace()で、それが何処のグラフか明示しなくてはなりません。
この時、(row,col)の場所のグラフの設定とmake_subplotsで指定した設定が合わないとエラーが起こります。
今回で言うと、(rows,col)=(1,1)ではsecondary_y を使うといっているので、その設定を書いています。もちろん、設定していないオプションを使おうとするとエラーが出ます。
今回の例では、

fig.add_trace(go.Scatter(),row=2, col=1,secondary_y=False )

とするとエラーが出ます。
他にも色々設定がありますが、詳しくは公式ドキュメントを読んでください。

この記事で使ったコード

google colabで動かせるコード全体を置いておきます。

まとめ

  1. plotlyのグラフが上手く記事に埋め込めなかったのでグラフは画像です。
  2. matplotlibと同じ感じです。
  3. キャンドルスティックを描くには先ほど説明した複数の値が必要なので、y= の部分が大量にあります。
  4. somegraph の部分には、scatter とかが入ります。
  5. 今回のコードのtextposition=”bottom left” には、特に意味がありませんが。