【深度学习】四种天气分类 模版函数 从0到1手敲版本

news/2024/7/23 9:57:22 标签: 深度学习, 分类, 人工智能

引入该引入的库

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torchvision
import torch.optim as optim
%matplotlib inline
import os
import shutil
import glob
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

注意:os.environ[“KMP_DUPLICATE_LIB_OK”]=“TRUE” 必须要引入否则用plt出错

数据集整理

img_dir = r"F:\播放器\1、pytorch全套入门与实战项目\课程资料\参考代码和部分数据集\参考代码\参考代码\29-42节参考代码和数据集\四种天气图片数据集\dataset2"
base_dir = r"./dataset/4weather"

img_list = glob.glob(img_dir+"/*.*")
test_dir = "test"
train_dir = "train"
species = ["cloudy","rain","shine","sunrise"]
for idx,img_path in enumerate(img_list):
    _,img_name = os.path.split(img_path)
    if idx%5==0:

        for specie in species:
            if img_path.find(specie) > -1:
                dst_dir = os.path.join(test_dir,specie)
                os.makedirs(dst_dir,exist_ok=True)
                dst_path = os.path.join(dst_dir,img_name)
    else:
        
        for specie in species:
            if img_path.find(specie) > -1:
                dst_dir = os.path.join(train_dir,specie)
                os.makedirs(dst_dir,exist_ok=True)
                dst_path = os.path.join(dst_dir,img_name)
    shutil.copy(img_path,dst_path)

生成测试和训练的文件夹,
目录结构如下:
在这里插入图片描述
rain 下面就是图片了
在这里插入图片描述

构建ds和dl

from torchvision import transforms
transform = transforms.Compose([transforms.Resize((96,96)),transforms.ToTensor(),transforms.Normalize(mean=[0.5,0.5,0.5],std=[0.5,0.5,0.5])])
train_ds=torchvision.datasets.ImageFolder(train_dir,transform)
test_ds = torchvision.datasets.ImageFolder(train_dir,transform)

在这里插入图片描述
在这里插入图片描述
一张图片效果,这是rain图片 这里需要转换维度,把channel放到最后。同时把数据拉到0-1之间,原本std 和mean 【0.5,0,5】数据在-0.5~0.5之间
在这里插入图片描述
类的映射
在这里插入图片描述

plt.figure(figsize=(12, 8))
for i, (img, label) in enumerate(zip(imgs[:6], labels[:6])):
    img = (img.permute(1, 2, 0).numpy() + 1)/2
    plt.subplot(2, 3, i+1)
    plt.title(id_to_class.get(label.item()))
    plt.imshow(img)

这个方法要学会
在这里插入图片描述

定义网络

class Net(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3,16,3)
        self.conv2 = nn.Conv2d(16,32,3)
        self.conv3 = nn.Conv2d(32,64,3)
        self.pool = nn.MaxPool2d(2,2)
        self.dropout = nn.Dropout(0.3)
        self.fc1 = nn.Linear(64*10*10,1024)
        self.fc2 = nn.Linear(1024,4)
    def forward(self,x):
        x = F.relu(self.conv1(x))
        x = self.pool(x)
        x = F.relu(self.conv2(x))
        x = self.pool(x)
        x = F.relu(self.conv3(x))
        x = self.pool(x)
        x = self.dropout(x)
        # print(x.size()) 这里是可以计算出来的,需要掌握计算方法
        x = x.view(-1,64*10*10)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)
model = Net()        
preds = model(imgs)
preds.shape, preds

在这里插入图片描述
定义损失函数和优化函数:

loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.Adam(model.parameters(),lr=0.001)

定义网络

def fit(epoch, model, trainloader, testloader):
    correct = 0
    total = 0
    running_loss = 0
    for x, y in trainloader:
        if torch.cuda.is_available():
            x, y = x.to('cuda'), y.to('cuda')
        y_pred = model(x)
        loss = loss_fn(y_pred, y)
        optim.zero_grad()
        loss.backward()
        optim.step()
        with torch.no_grad():
            y_pred = torch.argmax(y_pred, dim=1)
            correct += (y_pred == y).sum().item()
            total += y.size(0)
            running_loss += loss.item()
        
    epoch_loss = running_loss / len(trainloader.dataset)
    epoch_acc = correct / total
        
        
    test_correct = 0
    test_total = 0
    test_running_loss = 0 
    
    with torch.no_grad():
        for x, y in testloader:
            if torch.cuda.is_available():
                x, y = x.to('cuda'), y.to('cuda')
            y_pred = model(x)
            loss = loss_fn(y_pred, y)
            y_pred = torch.argmax(y_pred, dim=1)
            test_correct += (y_pred == y).sum().item()
            test_total += y.size(0)
            test_running_loss += loss.item()
    
    epoch_test_loss = test_running_loss / len(testloader.dataset)
    epoch_test_acc = test_correct / test_total
    
        
    print('epoch: ', epoch, 
          'loss: ', round(epoch_loss, 3),
          'accuracy:', round(epoch_acc, 3),
          'test_loss: ', round(epoch_test_loss, 3),
          'test_accuracy:', round(epoch_test_acc, 3)
             )
        
    return epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc

训练:

epochs = 30
train_loss = []
train_acc = []
test_loss = []
test_acc = []

for epoch in range(epochs):
    epoch_loss, epoch_acc, epoch_test_loss, epoch_test_acc = fit(epoch,
                                                                 model,
                                                                 train_dl,
                                                                 test_dl)
    train_loss.append(epoch_loss)
    train_acc.append(epoch_acc)
    test_loss.append(epoch_test_loss)
    test_acc.append(epoch_test_acc)
epoch:  0 loss:  0.043 accuracy: 0.714 test_loss:  0.029 test_accuracy: 0.809
epoch:  1 loss:  0.03 accuracy: 0.807 test_loss:  0.023 test_accuracy: 0.867
epoch:  2 loss:  0.024 accuracy: 0.857 test_loss:  0.018 test_accuracy: 0.888
epoch:  3 loss:  0.021 accuracy: 0.869 test_loss:  0.017 test_accuracy: 0.894
epoch:  4 loss:  0.018 accuracy: 0.886 test_loss:  0.014 test_accuracy: 0.921
epoch:  5 loss:  0.017 accuracy: 0.897 test_loss:  0.022 test_accuracy: 0.869
epoch:  6 loss:  0.013 accuracy: 0.923 test_loss:  0.008 test_accuracy: 0.944
epoch:  7 loss:  0.009 accuracy: 0.947 test_loss:  0.011 test_accuracy: 0.924
epoch:  8 loss:  0.006 accuracy: 0.966 test_loss:  0.004 test_accuracy: 0.988
epoch:  9 loss:  0.004 accuracy: 0.979 test_loss:  0.002 test_accuracy: 0.998
epoch:  10 loss:  0.004 accuracy: 0.979 test_loss:  0.005 test_accuracy: 0.966

在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
比较重要的点,
1.分类的数据集布局要记住
2.图片经过conv2 多次后的值要会算 todo
3.图片展示的方法要会


http://www.niftyadmin.cn/n/5447298.html

相关文章

Python 打包输出exe

Python 打包输出exe python 中可以使用第三方库pyinstaller 来将Python代码打包成可执行程序exe. pyinstaller 会将python解释器,依赖的库,以及代码打包成一个单独的可执行文件.这样即使没有安装python 也可以运行程序. 1: 安装pyinstaller pip install pyinstaller2: 打包…

2024最新华为OD机试试题库全 -【虚拟理财游戏】- C卷

1. 🌈题目详情 1.1 ⚠️题目 在一款虚拟游戏中生活,你必须进行投资以增强在虚拟游戏中的资产以免被淘汰出局。 现有一家Bank,它提供有若干理财产品 m 个,风险及投资回报不同,你有 N(元)进行投资,能接收的总风险值为X。 你要在可接受范围内选择最优的投资方式获得最…

阿里云云效-项目管理快速入门

目录 什么是项目 创建第一个项目 开始项目内协作 开启敏捷需求管理 什么是项目 项目,是围绕某一特定目标(如产品交付或服务),组织相应人力、物力资源进行的临时性工作项目具有非常强的计划性,有确定的开始日期和结…

勒索病毒变种.360、.halo勒索病毒来袭,如何恢复受感染的数据?

引言: 随着信息技术的飞速发展,网络安全问题日益凸显。近年来,勒索病毒成为了网络安全领域的一大难题,其中.360、.halo勒索病毒更是备受关注。本文将探讨.360、.halo勒索病毒的危害及应对策略,帮助大家更好地应对这一…

I2C系列(三):软件模拟I2C读写24C04

一.目标 PC 端的串口调试软件通过 RS-485 与单片机通信,控制单片机利用软件模拟 I2C 总线对 EEPROM(24C04) 进行任意读写。 二.RS-485简述 在工业控制领域,传输距离越长,要求抗干扰能力也越强。由于 RS-232 无法消除…

用cavans给图片加水印

原图片 index.html <html><head><meta charset"utf-8"><meta http-equiv"X-UA-Compatible" content"IEedge"><title>images</title><script src"./vue.global.js"></script> </h…

clang-query 的编译安装与使用示例

1&#xff0c;clang query 概述 作用&#xff1a; 检查一个程序源码的抽象语法树&#xff0c;测试 AST 匹配器&#xff1b; 帮助检查哪些 AST 节点与指定的 AST 匹配器相匹配&#xff1b; 2&#xff0c;clang-query 安装 准备&#xff1a; git clone --recursive https://git…

creator-webview与Android交互

title: creator-webview与Android交互 categories: Cocos2dx tags: [cocos2dx, creator, webview, 交互] date: 2024-03-23 13:17:20 comments: false mathjax: true toc: true creator-webview与Android交互 前篇 Android&#xff1a;你要的WebView与 JS 交互方式 都在这里了…